forked from mindspore-Ecosystem/mindspore
1. support get_all_edges, get_nodes_from_edge, get_sampled_neighbors, get_neg_sampled_neighbors and graph_info API
2. mod cora and citeseer conversion script
This commit is contained in:
parent
444d9484d7
commit
3ece8dd090
|
@ -24,9 +24,6 @@ This example provides an efficient way to generate MindRecord. Users only need t
|
|||
|
||||
1. Download and prepare the Cora dataset as required.
|
||||
|
||||
> [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)
|
||||
|
||||
|
||||
2. Edit write_cora.sh and modify the parameters
|
||||
```
|
||||
--mindrecord_file: output MindRecord file.
|
||||
|
|
|
@ -15,29 +15,26 @@
|
|||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import csv
|
||||
import os
|
||||
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
CITESEER_CONTENT_FILE = args[0]
|
||||
CITESEER_CITES_FILE = args[1]
|
||||
CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
|
||||
CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
|
||||
|
||||
node_id_map = {}
|
||||
CITESEER_PATH = args[0]
|
||||
dataset_str = 'citeseer'
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
|
||||
node_profile = (2, ["float32", "int32"], [[-1], [-1]])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
node_ids = []
|
||||
|
||||
|
||||
def _normalize_citeseer_features(features):
|
||||
features = np.array(features)
|
||||
row_sum = np.array(features.sum(1))
|
||||
r_inv = np.power(row_sum * 1.0, -1).flatten()
|
||||
r_inv[np.isinf(r_inv)] = 0.
|
||||
|
@ -46,6 +43,14 @@ def _normalize_citeseer_features(features):
|
|||
return features
|
||||
|
||||
|
||||
def _parse_index_file(filename):
|
||||
"""Parse index file."""
|
||||
index = []
|
||||
for line in open(filename):
|
||||
index.append(int(line.strip()))
|
||||
return index
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
@ -54,29 +59,46 @@ def yield_nodes(task_id=0):
|
|||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
label_types = {}
|
||||
label_size = 0
|
||||
node_num = 0
|
||||
with open(CITESEER_CONTENT_FILE) as content_file:
|
||||
content_reader = csv.reader(content_file, delimiter='\t')
|
||||
line_count = 0
|
||||
for row in content_reader:
|
||||
if not row[-1] in label_types:
|
||||
label_types[row[-1]] = label_size
|
||||
label_size += 1
|
||||
if not row[0] in node_id_map:
|
||||
node_id_map[row[0]] = node_num
|
||||
node_num += 1
|
||||
raw_features = [[int(x) for x in row[1:-1]]]
|
||||
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features),
|
||||
'feature_2': [label_types[row[-1]]]}
|
||||
yield node
|
||||
line_count += 1
|
||||
names = ['x', 'y', 'tx', 'ty', 'allx', 'ally']
|
||||
objects = []
|
||||
for name in names:
|
||||
with open("{}/ind.{}.{}".format(CITESEER_PATH, dataset_str, name), 'rb') as f:
|
||||
objects.append(pkl.load(f, encoding='latin1'))
|
||||
x, y, tx, ty, allx, ally = tuple(objects)
|
||||
test_idx_reorder = _parse_index_file(
|
||||
"{}/ind.{}.test.index".format(CITESEER_PATH, dataset_str))
|
||||
test_idx_range = np.sort(test_idx_reorder)
|
||||
|
||||
tx = _normalize_citeseer_features(tx)
|
||||
allx = _normalize_citeseer_features(allx)
|
||||
|
||||
# Fix citeseer dataset (there are some isolated nodes in the graph)
|
||||
# Find isolated nodes, add them as zero-vecs into the right position
|
||||
test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
|
||||
tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
|
||||
tx_extended[test_idx_range-min(test_idx_range), :] = tx
|
||||
tx = tx_extended
|
||||
ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
|
||||
ty_extended[test_idx_range-min(test_idx_range), :] = ty
|
||||
ty = ty_extended
|
||||
|
||||
features = sp.vstack((allx, tx)).tolil()
|
||||
features[test_idx_reorder, :] = features[test_idx_range, :]
|
||||
features = features.A
|
||||
|
||||
labels = np.vstack((ally, ty))
|
||||
labels[test_idx_reorder, :] = labels[test_idx_range, :]
|
||||
|
||||
line_count = 0
|
||||
for i, label in enumerate(labels):
|
||||
if not 1 in label.tolist():
|
||||
continue
|
||||
node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
|
||||
'feature_2': label.tolist().index(1)}
|
||||
line_count += 1
|
||||
node_ids.append(i)
|
||||
yield node
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
# print('label types {}.'.format(label_types))
|
||||
with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
|
||||
for k in label_types:
|
||||
print(k + ',' + str(label_types[k]), file=f)
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
|
@ -87,23 +109,20 @@ def yield_edges(task_id=0):
|
|||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
# print(map_string_int)
|
||||
with open(CITESEER_CITES_FILE) as cites_file:
|
||||
cites_reader = csv.reader(cites_file, delimiter='\t')
|
||||
with open("{}/ind.{}.graph".format(CITESEER_PATH, dataset_str), 'rb') as f:
|
||||
graph = pkl.load(f, encoding='latin1')
|
||||
line_count = 0
|
||||
for row in cites_reader:
|
||||
if not row[0] in node_id_map:
|
||||
print('Source node {} does not exist.'.format(row[0]))
|
||||
continue
|
||||
if not row[1] in node_id_map:
|
||||
print('Destination node {} does not exist.'.format(row[1]))
|
||||
continue
|
||||
line_count += 1
|
||||
edge = {'id': line_count,
|
||||
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
|
||||
yield edge
|
||||
|
||||
with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
|
||||
for k in node_id_map:
|
||||
print(k + ',' + str(node_id_map[k]), file=f)
|
||||
for i in graph:
|
||||
for dst_id in graph[i]:
|
||||
if not i in node_ids:
|
||||
print('Source node {} does not exist.'.format(i))
|
||||
continue
|
||||
if not dst_id in node_ids:
|
||||
print('Destination node {} does not exist.'.format(
|
||||
dst_id))
|
||||
continue
|
||||
edge = {'id': line_count,
|
||||
'src_id': i, 'dst_id': dst_id, 'type': 0}
|
||||
line_count += 1
|
||||
yield edge
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
||||
|
|
|
@ -15,29 +15,24 @@
|
|||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import csv
|
||||
import os
|
||||
|
||||
import pickle as pkl
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
CORA_CONTENT_FILE = args[0]
|
||||
CORA_CITES_FILE = args[1]
|
||||
CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
|
||||
CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
|
||||
|
||||
node_id_map = {}
|
||||
CORA_PATH = args[0]
|
||||
dataset_str = 'cora'
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
|
||||
node_profile = (2, ["float32", "int32"], [[-1], [-1]])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
|
||||
def _normalize_cora_features(features):
|
||||
features = np.array(features)
|
||||
row_sum = np.array(features.sum(1))
|
||||
r_inv = np.power(row_sum * 1.0, -1).flatten()
|
||||
r_inv[np.isinf(r_inv)] = 0.
|
||||
|
@ -46,6 +41,14 @@ def _normalize_cora_features(features):
|
|||
return features
|
||||
|
||||
|
||||
def _parse_index_file(filename):
|
||||
"""Parse index file."""
|
||||
index = []
|
||||
for line in open(filename):
|
||||
index.append(int(line.strip()))
|
||||
return index
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
@ -54,32 +57,32 @@ def yield_nodes(task_id=0):
|
|||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
label_types = {}
|
||||
label_size = 0
|
||||
node_num = 0
|
||||
with open(CORA_CONTENT_FILE) as content_file:
|
||||
content_reader = csv.reader(content_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in content_reader:
|
||||
if line_count == 0:
|
||||
line_count += 1
|
||||
continue
|
||||
if not row[0] in node_id_map:
|
||||
node_id_map[row[0]] = node_num
|
||||
node_num += 1
|
||||
if not row[-1] in label_types:
|
||||
label_types[row[-1]] = label_size
|
||||
label_size += 1
|
||||
raw_features = [[int(x) for x in row[1:-1]]]
|
||||
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features),
|
||||
'feature_2': [label_types[row[-1]]]}
|
||||
yield node
|
||||
line_count += 1
|
||||
|
||||
names = ['tx', 'ty', 'allx', 'ally']
|
||||
objects = []
|
||||
for name in names:
|
||||
with open("{}/ind.{}.{}".format(CORA_PATH, dataset_str, name), 'rb') as f:
|
||||
objects.append(pkl.load(f, encoding='latin1'))
|
||||
tx, ty, allx, ally = tuple(objects)
|
||||
test_idx_reorder = _parse_index_file(
|
||||
"{}/ind.{}.test.index".format(CORA_PATH, dataset_str))
|
||||
test_idx_range = np.sort(test_idx_reorder)
|
||||
|
||||
features = sp.vstack((allx, tx)).tolil()
|
||||
features[test_idx_reorder, :] = features[test_idx_range, :]
|
||||
features = _normalize_cora_features(features)
|
||||
features = features.A
|
||||
|
||||
labels = np.vstack((ally, ty))
|
||||
labels[test_idx_reorder, :] = labels[test_idx_range, :]
|
||||
|
||||
line_count = 0
|
||||
for i, label in enumerate(labels):
|
||||
node = {'id': i, 'type': 0, 'feature_1': features[i].tolist(),
|
||||
'feature_2': label.tolist().index(1)}
|
||||
line_count += 1
|
||||
yield node
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
print('label types {}.'.format(label_types))
|
||||
with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
|
||||
for k in label_types:
|
||||
print(k + ',' + str(label_types[k]), file=f)
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
|
@ -90,24 +93,13 @@ def yield_edges(task_id=0):
|
|||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
with open(CORA_CITES_FILE) as cites_file:
|
||||
cites_reader = csv.reader(cites_file, delimiter=',')
|
||||
with open("{}/ind.{}.graph".format(CORA_PATH, dataset_str), 'rb') as f:
|
||||
graph = pkl.load(f, encoding='latin1')
|
||||
line_count = 0
|
||||
for row in cites_reader:
|
||||
if line_count == 0:
|
||||
for i in graph:
|
||||
for dst_id in graph[i]:
|
||||
edge = {'id': line_count,
|
||||
'src_id': i, 'dst_id': dst_id, 'type': 0}
|
||||
line_count += 1
|
||||
continue
|
||||
if not row[0] in node_id_map:
|
||||
print('Source node {} does not exist.'.format(row[0]))
|
||||
continue
|
||||
if not row[1] in node_id_map:
|
||||
print('Destination node {} does not exist.'.format(row[1]))
|
||||
continue
|
||||
edge = {'id': line_count,
|
||||
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
|
||||
yield edge
|
||||
line_count += 1
|
||||
yield edge
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
||||
with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
|
||||
for k in node_id_map:
|
||||
print(k + ',' + str(node_id_map[k]), file=f)
|
||||
|
|
|
@ -9,4 +9,4 @@ python writer.py --mindrecord_script citeseer \
|
|||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$SRC_PATH/citeseer.content:$SRC_PATH/citeseer.cites"
|
||||
--graph_api_args "$SRC_PATH"
|
||||
|
|
|
@ -9,4 +9,4 @@ python writer.py --mindrecord_script cora \
|
|||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "$SRC_PATH/cora_content.csv:$SRC_PATH/cora_cites.csv"
|
||||
--graph_api_args "$SRC_PATH"
|
||||
|
|
|
@ -527,10 +527,22 @@ void bindGraphData(py::module *m) {
|
|||
THROW_IF_ERROR(g_out->Init());
|
||||
return g_out;
|
||||
}))
|
||||
.def("get_nodes",
|
||||
[](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
|
||||
.def("get_all_nodes",
|
||||
[](gnn::Graph &g, gnn::NodeType node_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
|
||||
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_edges",
|
||||
[](gnn::Graph &g, gnn::EdgeType edge_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_nodes_from_edges",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_neighbors",
|
||||
|
@ -539,12 +551,31 @@ void bindGraphData(py::module *m) {
|
|||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_sampled_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
||||
std::vector<gnn::NodeType> neighbor_types) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_neg_sampled_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
|
||||
gnn::NodeType neg_neighbor_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_node_feature",
|
||||
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
||||
TensorRow out;
|
||||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||
return out;
|
||||
});
|
||||
})
|
||||
.def("graph_info", [](gnn::Graph &g) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(g.GraphInfo(&out));
|
||||
return out;
|
||||
});
|
||||
}
|
||||
|
||||
// This is where we externalize the C logic as python modules
|
||||
|
|
|
@ -17,29 +17,30 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/tensor_shape.h"
|
||||
#include "dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) {
|
||||
Graph::Graph(std::string dataset_file, int32_t num_workers)
|
||||
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) {
|
||||
rnd_.seed(GetSeed());
|
||||
MS_LOG(INFO) << "num_workers:" << num_workers;
|
||||
}
|
||||
|
||||
Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out) {
|
||||
Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
|
||||
auto itr = node_type_map_.find(node_type);
|
||||
if (itr == node_type_map_.end()) {
|
||||
std::string err_msg = "Invalid node type:" + std::to_string(node_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
if (node_num == -1) {
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
|
||||
} else {
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -59,9 +60,9 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data
|
|||
RETURN_IF_NOT_OK(Tensor::CreateTensor(
|
||||
&tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr));
|
||||
T *ptr = reinterpret_cast<T *>(tensor->GetMutableBuffer());
|
||||
for (auto id_m : data) {
|
||||
for (const auto &id_m : data) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size");
|
||||
for (auto id_n : id_m) {
|
||||
for (const auto &id_n : id_m) {
|
||||
*ptr = id_n;
|
||||
ptr++;
|
||||
}
|
||||
|
@ -89,7 +90,38 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out) { return Status::OK(); }
|
||||
Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
|
||||
auto itr = edge_type_map_.find(edge_type);
|
||||
if (itr == edge_type_map_.end()) {
|
||||
std::string err_msg = "Invalid edge type:" + std::to_string(edge_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
|
||||
if (edge_list.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Input edge_list is empty");
|
||||
}
|
||||
|
||||
std::vector<std::vector<NodeIdType>> node_list;
|
||||
node_list.reserve(edge_list.size());
|
||||
for (const auto &edge_id : edge_list) {
|
||||
auto itr = edge_id_map_.find(edge_id);
|
||||
if (itr == edge_id_map_.end()) {
|
||||
std::string err_msg = "Invalid edge id:" + std::to_string(edge_id);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> nodes;
|
||||
RETURN_IF_NOT_OK(itr->second->GetNode(&nodes));
|
||||
node_list.push_back({nodes.first->id(), nodes.second->id()});
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(node_list, DataType(DataType::DE_INT32), out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
|
@ -105,14 +137,10 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
|
|||
size_t max_neighbor_num = 0;
|
||||
neighbors.resize(node_list.size());
|
||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||
auto itr = node_id_map_.find(node_list[i]);
|
||||
if (itr != node_id_map_.end()) {
|
||||
RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i]));
|
||||
max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
|
||||
} else {
|
||||
std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node));
|
||||
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i]));
|
||||
max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId));
|
||||
|
@ -121,13 +149,94 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||
Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
|
||||
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
|
||||
std::vector<std::vector<NodeIdType>> neighbors_vec(node_list.size());
|
||||
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
||||
neighbors_vec[node_idx].emplace_back(node_list[node_idx]);
|
||||
std::vector<NodeIdType> input_list = {node_list[node_idx]};
|
||||
for (size_t i = 0; i < neighbor_nums.size(); ++i) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
neighbors.reserve(input_list.size() * neighbor_nums[i]);
|
||||
for (const auto &node_id : input_list) {
|
||||
if (node_id == kDefaultNodeId) {
|
||||
for (int32_t j = 0; j < neighbor_nums[i]; ++j) {
|
||||
neighbors.emplace_back(kDefaultNodeId);
|
||||
}
|
||||
} else {
|
||||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
|
||||
std::vector<NodeIdType> out;
|
||||
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out));
|
||||
neighbors.insert(neighbors.end(), out.begin(), out.end());
|
||||
}
|
||||
}
|
||||
neighbors_vec[node_idx].insert(neighbors_vec[node_idx].end(), neighbors.begin(), neighbors.end());
|
||||
input_list = std::move(neighbors);
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
||||
Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data,
|
||||
int32_t samples_num, std::vector<NodeIdType> *out_samples) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
|
||||
std::vector<NodeIdType> shuffled_id(data.size());
|
||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
for (const auto &index : shuffled_id) {
|
||||
if (exclude_data.find(data[index]) != exclude_data.end()) {
|
||||
continue;
|
||||
}
|
||||
out_samples->emplace_back(data[index]);
|
||||
if (out_samples->size() >= samples_num) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
std::vector<std::vector<NodeIdType>> neighbors_vec;
|
||||
neighbors_vec.resize(node_list.size());
|
||||
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
||||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[node_idx], &node));
|
||||
std::vector<NodeIdType> neighbors;
|
||||
RETURN_IF_NOT_OK(node->GetAllNeighbors(neg_neighbor_type, &neighbors));
|
||||
std::unordered_set<NodeIdType> exclude_node;
|
||||
std::transform(neighbors.begin(), neighbors.end(),
|
||||
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_node, exclude_node.begin()),
|
||||
[](const NodeIdType node) { return node; });
|
||||
auto itr = node_type_map_.find(neg_neighbor_type);
|
||||
if (itr == node_type_map_.end()) {
|
||||
std::string err_msg = "Invalid node type:" + std::to_string(neg_neighbor_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
neighbors_vec[node_idx].emplace_back(node->id());
|
||||
if (itr->second.size() > exclude_node.size()) {
|
||||
while (neighbors_vec[node_idx].size() < samples_num + 1) {
|
||||
RETURN_IF_NOT_OK(NegativeSample(itr->second, exclude_node, samples_num - neighbors_vec[node_idx].size(),
|
||||
&neighbors_vec[node_idx]));
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
|
||||
<< " neg_neighbor_type:" << neg_neighbor_type;
|
||||
// If there are no negative neighbors, they are filled with kDefaultNodeId
|
||||
for (int32_t i = 0; i < samples_num; ++i) {
|
||||
neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors_vec, DataType(DataType::DE_INT32), out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -154,7 +263,7 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
|
|||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Inpude feature_types is empty");
|
||||
TensorRow tensors;
|
||||
for (auto f_type : feature_types) {
|
||||
for (const auto &f_type : feature_types) {
|
||||
std::shared_ptr<Feature> default_feature;
|
||||
// If no feature can be obtained, fill in the default value
|
||||
RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature));
|
||||
|
@ -169,18 +278,14 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
|
|||
|
||||
dsize_t index = 0;
|
||||
for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
|
||||
auto itr = node_id_map_.find(*node_itr);
|
||||
std::shared_ptr<Feature> feature;
|
||||
if (itr != node_id_map_.end()) {
|
||||
if (!itr->second->GetFeatures(f_type, &feature).IsOk()) {
|
||||
feature = default_feature;
|
||||
}
|
||||
if (*node_itr == kDefaultNodeId) {
|
||||
feature = default_feature;
|
||||
} else {
|
||||
if (*node_itr == kDefaultNodeId) {
|
||||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
|
||||
if (!node->GetFeatures(f_type, &feature).IsOk()) {
|
||||
feature = default_feature;
|
||||
} else {
|
||||
std::string err_msg = "Invalid node id:" + std::to_string(*node_itr);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));
|
||||
|
@ -209,35 +314,54 @@ Status Graph::Init() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info) {
|
||||
node_info->reserve(node_type_map_.size());
|
||||
for (auto node : node_type_map_) {
|
||||
NodeMetaInfo n_info;
|
||||
n_info.type = node.first;
|
||||
n_info.num = node.second.size();
|
||||
auto itr = node_feature_map_.find(node.first);
|
||||
if (itr != node_feature_map_.end()) {
|
||||
for (auto f_type : itr->second) {
|
||||
n_info.feature_type.push_back(f_type);
|
||||
}
|
||||
std::sort(n_info.feature_type.begin(), n_info.feature_type.end());
|
||||
}
|
||||
node_info->push_back(n_info);
|
||||
Status Graph::GetMetaInfo(MetaInfo *meta_info) {
|
||||
meta_info->node_type.resize(node_type_map_.size());
|
||||
std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(),
|
||||
[](auto itr) { return itr.first; });
|
||||
std::sort(meta_info->node_type.begin(), meta_info->node_type.end());
|
||||
|
||||
meta_info->edge_type.resize(edge_type_map_.size());
|
||||
std::transform(edge_type_map_.begin(), edge_type_map_.end(), meta_info->edge_type.begin(),
|
||||
[](auto itr) { return itr.first; });
|
||||
std::sort(meta_info->edge_type.begin(), meta_info->edge_type.end());
|
||||
|
||||
for (const auto &node : node_type_map_) {
|
||||
meta_info->node_num[node.first] = node.second.size();
|
||||
}
|
||||
|
||||
edge_info->reserve(edge_type_map_.size());
|
||||
for (auto edge : edge_type_map_) {
|
||||
EdgeMetaInfo e_info;
|
||||
e_info.type = edge.first;
|
||||
e_info.num = edge.second.size();
|
||||
auto itr = edge_feature_map_.find(edge.first);
|
||||
if (itr != edge_feature_map_.end()) {
|
||||
for (auto f_type : itr->second) {
|
||||
e_info.feature_type.push_back(f_type);
|
||||
}
|
||||
}
|
||||
edge_info->push_back(e_info);
|
||||
for (const auto &edge : edge_type_map_) {
|
||||
meta_info->edge_num[edge.first] = edge.second.size();
|
||||
}
|
||||
|
||||
for (const auto &node_feature : node_feature_map_) {
|
||||
for (auto type : node_feature.second) {
|
||||
meta_info->node_feature_type.emplace_back(type);
|
||||
}
|
||||
}
|
||||
std::sort(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end());
|
||||
auto unique_node = std::unique(meta_info->node_feature_type.begin(), meta_info->node_feature_type.end());
|
||||
meta_info->node_feature_type.erase(unique_node, meta_info->node_feature_type.end());
|
||||
|
||||
for (const auto &edge_feature : edge_feature_map_) {
|
||||
for (const auto &type : edge_feature.second) {
|
||||
meta_info->edge_feature_type.emplace_back(type);
|
||||
}
|
||||
}
|
||||
std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
|
||||
auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
|
||||
meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GraphInfo(py::dict *out) {
|
||||
MetaInfo meta_info;
|
||||
RETURN_IF_NOT_OK(GetMetaInfo(&meta_info));
|
||||
(*out)["node_type"] = py::cast(meta_info.node_type);
|
||||
(*out)["edge_type"] = py::cast(meta_info.edge_type);
|
||||
(*out)["node_num"] = py::cast(meta_info.node_num);
|
||||
(*out)["edge_num"] = py::cast(meta_info.edge_num);
|
||||
(*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
|
||||
(*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -250,6 +374,18 @@ Status Graph::LoadNodeAndEdge() {
|
|||
&node_feature_map_, &edge_feature_map_, &default_feature_map_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
|
||||
auto itr = node_id_map_.find(id);
|
||||
if (itr == node_id_map_.end()) {
|
||||
std::string err_msg = "Invalid node id:" + std::to_string(id);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
*node = itr->second;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
@ -33,24 +34,13 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
struct NodeMetaInfo {
|
||||
NodeType type;
|
||||
NodeIdType num;
|
||||
std::vector<FeatureType> feature_type;
|
||||
NodeMetaInfo() {
|
||||
type = 0;
|
||||
num = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct EdgeMetaInfo {
|
||||
EdgeType type;
|
||||
EdgeIdType num;
|
||||
std::vector<FeatureType> feature_type;
|
||||
EdgeMetaInfo() {
|
||||
type = 0;
|
||||
num = 0;
|
||||
}
|
||||
struct MetaInfo {
|
||||
std::vector<NodeType> node_type;
|
||||
std::vector<EdgeType> edge_type;
|
||||
std::map<NodeType, NodeIdType> node_num;
|
||||
std::map<EdgeType, EdgeIdType> edge_num;
|
||||
std::vector<FeatureType> node_feature_type;
|
||||
std::vector<FeatureType> edge_feature_type;
|
||||
};
|
||||
|
||||
class Graph {
|
||||
|
@ -62,19 +52,23 @@ class Graph {
|
|||
|
||||
~Graph() = default;
|
||||
|
||||
// Get the nodes from the graph.
|
||||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out);
|
||||
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get the edges from the graph.
|
||||
// Get all edges from the graph.
|
||||
// @param NodeType edge_type - type of edge
|
||||
// @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out);
|
||||
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get the node id from the edge.
|
||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||
// @return Status - The error code return
|
||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
|
@ -86,10 +80,24 @@ class Graph {
|
|||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
|
||||
Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
|
||||
// Get sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status - The error code return
|
||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get negative sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||
// @return Status - The error code return
|
||||
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
|
||||
NodeIdType default_node, std::shared_ptr<Tensor> *out);
|
||||
|
||||
|
@ -112,10 +120,12 @@ class Graph {
|
|||
TensorRow *out);
|
||||
|
||||
// Get meta information of graph
|
||||
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node
|
||||
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge
|
||||
// @param MetaInfo *meta_info - Returned meta information
|
||||
// @return Status - The error code return
|
||||
Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info);
|
||||
Status GetMetaInfo(MetaInfo *meta_info);
|
||||
|
||||
// Return meta information to python layer
|
||||
Status GraphInfo(py::dict *out);
|
||||
|
||||
Status Init();
|
||||
|
||||
|
@ -146,8 +156,24 @@ class Graph {
|
|||
// @return Status - The error code return
|
||||
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
|
||||
|
||||
// Find node object using node id
|
||||
// @param NodeIdType id -
|
||||
// @param std::shared_ptr<Node> *node - Returned node object
|
||||
// @return Status - The error code return
|
||||
Status GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node);
|
||||
|
||||
// Negative sampling
|
||||
// @param std::vector<NodeIdType> &input_data - The data set to be sampled
|
||||
// @param std::unordered_set<NodeIdType> &exclude_data - Data to be excluded
|
||||
// @param int32_t samples_num -
|
||||
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
|
||||
// @return Status - The error code return
|
||||
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
|
||||
int32_t samples_num, std::vector<NodeIdType> *out_samples);
|
||||
|
||||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
std::mt19937 rnd_;
|
||||
|
||||
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
|
||||
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
|
||||
|
|
|
@ -20,12 +20,13 @@
|
|||
#include <utility>
|
||||
|
||||
#include "dataset/engine/gnn/edge.h"
|
||||
#include "dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {}
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); }
|
||||
|
||||
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
|
@ -38,21 +39,49 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
|
|||
}
|
||||
}
|
||||
|
||||
Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) {
|
||||
Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
if (samples_num == -1) {
|
||||
// Return all neighbors
|
||||
neighbors.resize(itr->second.size() + 1);
|
||||
neighbors[0] = id_;
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
} else {
|
||||
neighbors.resize(itr->second.size() + 1);
|
||||
neighbors[0] = id_;
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
|
||||
neighbors.emplace_back(id_);
|
||||
}
|
||||
*out_neighbors = std::move(neighbors);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out) {
|
||||
std::vector<NodeIdType> shuffled_id(neighbors.size());
|
||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size()));
|
||||
for (int32_t i = 0; i < num; ++i) {
|
||||
out->emplace_back(neighbors[shuffled_id[i]]->id());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_neighbors) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
neighbors.reserve(samples_num);
|
||||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
while (neighbors.size() < samples_num) {
|
||||
RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors));
|
||||
}
|
||||
} else {
|
||||
neighbors.push_back(id_);
|
||||
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
|
||||
MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
|
||||
// If there are no neighbors, they are filled with kDefaultNodeId
|
||||
for (int32_t i = 0; i < samples_num; ++i) {
|
||||
neighbors.emplace_back(kDefaultNodeId);
|
||||
}
|
||||
}
|
||||
*out_neighbors = std::move(neighbors);
|
||||
return Status::OK();
|
||||
|
|
|
@ -43,12 +43,19 @@ class LocalNode : public Node {
|
|||
// @return Status - The error code return
|
||||
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
|
||||
|
||||
// Get the neighbors of a node
|
||||
// Get the all neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) override;
|
||||
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override;
|
||||
|
||||
// Get the sampled neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_neighbors) override;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
|
@ -61,6 +68,10 @@ class LocalNode : public Node {
|
|||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out);
|
||||
|
||||
std::mt19937 rnd_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_;
|
||||
};
|
||||
|
|
|
@ -52,12 +52,19 @@ class Node {
|
|||
// @return Status - The error code return
|
||||
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
|
||||
|
||||
// Get the neighbors of a node
|
||||
// Get the all neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
|
||||
// Get the sampled neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
|
|
|
@ -20,8 +20,9 @@ import numpy as np
|
|||
from mindspore._c_dataengine import Graph
|
||||
from mindspore._c_dataengine import Tensor
|
||||
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \
|
||||
check_gnn_get_node_feature
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
|
||||
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature
|
||||
|
||||
|
||||
class GraphData:
|
||||
|
@ -60,7 +61,44 @@ class GraphData:
|
|||
Raises:
|
||||
TypeError: If `node_type` is not integer.
|
||||
"""
|
||||
return self._graph.get_nodes(node_type, -1).as_array()
|
||||
return self._graph.get_all_nodes(node_type).as_array()
|
||||
|
||||
@check_gnn_get_all_edges
|
||||
def get_all_edges(self, edge_type):
|
||||
"""
|
||||
Get all edges in the graph.
|
||||
|
||||
Args:
|
||||
edge_type (int): Specify the type of edge.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of edges.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.get_all_edges(0)
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_type` is not integer.
|
||||
"""
|
||||
return self._graph.get_all_edges(edge_type).as_array()
|
||||
|
||||
@check_gnn_get_nodes_from_edges
|
||||
def get_nodes_from_edges(self, edge_list):
|
||||
"""
|
||||
Get nodes from the edges.
|
||||
|
||||
Args:
|
||||
edge_list (list or numpy.ndarray): The given list of edges.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
"""
|
||||
return self._graph.get_nodes_from_edges(edge_list).as_array()
|
||||
|
||||
@check_gnn_get_all_neighbors
|
||||
def get_all_neighbors(self, node_list, neighbor_type):
|
||||
|
@ -86,6 +124,58 @@ class GraphData:
|
|||
"""
|
||||
return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()
|
||||
|
||||
@check_gnn_get_sampled_neighbors
|
||||
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
|
||||
"""
|
||||
Get sampled neighbor information, maximum support 6-hop sampling.
|
||||
|
||||
Args:
|
||||
node_list (list or numpy.ndarray): The given list of nodes.
|
||||
neighbor_nums (list or numpy.ndarray): Number of neighbors sampled per hop.
|
||||
neighbor_types (list or numpy.ndarray): Neighbor type sampled per hop.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.get_all_nodes(0)
|
||||
>>> neighbors = data_graph.get_all_neighbors(nodes, [2, 2], [0, 0])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
TypeError: If `neighbor_nums` is not list or ndarray.
|
||||
TypeError: If `neighbor_types` is not list or ndarray.
|
||||
"""
|
||||
return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array()
|
||||
|
||||
@check_gnn_get_neg_sampled_neighbors
|
||||
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
|
||||
"""
|
||||
Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`.
|
||||
|
||||
Args:
|
||||
node_list (list or numpy.ndarray): The given list of nodes.
|
||||
neg_neighbor_num (int): Number of neighbors sampled.
|
||||
neg_neighbor_type (int): Specify the type of negative neighbor.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.get_all_nodes(0)
|
||||
>>> neg_neighbors = data_graph.get_neg_sampled_neighbors(nodes, 5, 0)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
TypeError: If `neg_neighbor_num` is not integer.
|
||||
TypeError: If `neg_neighbor_type` is not integer.
|
||||
"""
|
||||
return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array()
|
||||
|
||||
@check_gnn_get_node_feature
|
||||
def get_node_feature(self, node_list, feature_types):
|
||||
"""
|
||||
|
@ -111,3 +201,13 @@ class GraphData:
|
|||
if isinstance(node_list, list):
|
||||
node_list = np.array(node_list, dtype=np.int32)
|
||||
return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
|
||||
|
||||
def graph_info(self):
|
||||
"""
|
||||
Get the meta information of the graph, including the number of nodes, the type of nodes,
|
||||
the feature information of nodes, the number of edges, the type of edges, and the feature information of edges.
|
||||
Returns:
|
||||
Dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
|
||||
node_feature_type and edge_feature_type.
|
||||
"""
|
||||
return self._graph.graph_info()
|
||||
|
|
|
@ -1153,6 +1153,36 @@ def check_gnn_get_all_nodes(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_all_edges(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_edges` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_type; required argument
|
||||
check_type(param_dict.get("edge_type"), 'edge_type', int)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_nodes_from_edges(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_nodes_from_edges` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check edge_list; required argument
|
||||
check_gnn_list_or_ndarray(param_dict.get("edge_list"), 'edge_list')
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_all_neighbors(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function."""
|
||||
|
||||
|
@ -1171,6 +1201,61 @@ def check_gnn_get_all_neighbors(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_sampled_neighbors(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_sampled_neighbors` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
||||
|
||||
# check neighbor_nums; required argument
|
||||
neighbor_nums = param_dict.get("neighbor_nums")
|
||||
check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
|
||||
if len(neighbor_nums) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
|
||||
'neighbor_nums', len(neighbor_nums)))
|
||||
|
||||
# check neighbor_types; required argument
|
||||
neighbor_types = param_dict.get("neighbor_types")
|
||||
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
|
||||
if len(neighbor_nums) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
|
||||
'neighbor_types', len(neighbor_types)))
|
||||
|
||||
if len(neighbor_nums) != len(neighbor_types):
|
||||
raise ValueError(
|
||||
"The number of members of neighbor_nums and neighbor_types is inconsistent")
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_neg_sampled_neighbors(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_neg_sampled_neighbors` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
||||
|
||||
# check neg_neighbor_num; required argument
|
||||
check_type(param_dict.get("neg_neighbor_num"), 'neg_neighbor_num', int)
|
||||
|
||||
# check neg_neighbor_type; required argument
|
||||
check_type(param_dict.get("neg_neighbor_type"),
|
||||
'neg_neighbor_type', int)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_aligned_list(param, param_name, membor_type):
|
||||
"""Check whether the structure of each member of the list is the same."""
|
||||
|
||||
|
|
|
@ -13,8 +13,10 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
@ -45,7 +47,7 @@ TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
|
|||
&default_feature_map)
|
||||
.IsOk());
|
||||
EXPECT_EQ(n_id_map.size(), 20);
|
||||
EXPECT_EQ(e_id_map.size(), 20);
|
||||
EXPECT_EQ(e_id_map.size(), 40);
|
||||
EXPECT_EQ(n_type_map[2].size(), 10);
|
||||
EXPECT_EQ(n_type_map[1].size(), 10);
|
||||
}
|
||||
|
@ -56,14 +58,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
|||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
std::vector<NodeMetaInfo> node_info;
|
||||
std::vector<EdgeMetaInfo> edge_info;
|
||||
s = graph.GetMetaInfo(&node_info, &edge_info);
|
||||
MetaInfo meta_info;
|
||||
s = graph.GetMetaInfo(&meta_info);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(node_info.size() == 2);
|
||||
EXPECT_TRUE(meta_info.node_type.size() == 2);
|
||||
|
||||
std::shared_ptr<Tensor> nodes;
|
||||
s = graph.GetNodes(node_info[1].type, -1, &nodes);
|
||||
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::vector<NodeIdType> node_list;
|
||||
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
|
||||
|
@ -73,13 +74,13 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
|||
}
|
||||
}
|
||||
std::shared_ptr<Tensor> neighbors;
|
||||
s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors);
|
||||
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
|
||||
TensorRow features;
|
||||
s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features);
|
||||
s = graph.GetNodeFeature(nodes, meta_info.node_feature_type, &features);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(features.size() == 3);
|
||||
EXPECT_TRUE(features.size() == 4);
|
||||
EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>");
|
||||
EXPECT_TRUE(features[0]->ToString() ==
|
||||
"Tensor (shape: <10,5>, Type: int32)\n"
|
||||
|
@ -91,3 +92,106 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
|||
EXPECT_TRUE(features[2]->shape().ToString() == "<10>");
|
||||
EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
Graph graph(path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
MetaInfo meta_info;
|
||||
s = graph.GetMetaInfo(&meta_info);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(meta_info.node_type.size() == 2);
|
||||
|
||||
std::shared_ptr<Tensor> edges;
|
||||
s = graph.GetAllEdges(meta_info.edge_type[0], &edges);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::vector<EdgeIdType> edge_list;
|
||||
edge_list.resize(edges->Size());
|
||||
std::transform(edges->begin<EdgeIdType>(), edges->end<EdgeIdType>(), edge_list.begin(),
|
||||
[](const EdgeIdType edge) { return edge; });
|
||||
|
||||
std::shared_ptr<Tensor> nodes;
|
||||
s = graph.GetNodesFromEdges(edge_list, &nodes);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::unordered_set<NodeIdType> node_set;
|
||||
std::vector<NodeIdType> node_list;
|
||||
int index = 0;
|
||||
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
|
||||
index++;
|
||||
if (index % 2 == 0) {
|
||||
continue;
|
||||
}
|
||||
node_set.emplace(*itr);
|
||||
if (node_set.size() >= 5) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
node_list.resize(node_set.size());
|
||||
std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
|
||||
|
||||
std::shared_ptr<Tensor> neighbors;
|
||||
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
|
||||
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
|
||||
std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
Graph graph(path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
MetaInfo meta_info;
|
||||
s = graph.GetMetaInfo(&meta_info);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(meta_info.node_type.size() == 2);
|
||||
|
||||
std::shared_ptr<Tensor> nodes;
|
||||
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::vector<NodeIdType> node_list;
|
||||
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
|
||||
node_list.push_back(*itr);
|
||||
if (node_list.size() >= 10) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::shared_ptr<Tensor> neg_neighbors;
|
||||
s = graph.GetNegSampledNeighbors(node_list, 3, meta_info.node_type[1], &neg_neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neg_neighbors->shape().ToString() == "<10,4>");
|
||||
|
||||
neg_neighbors.reset();
|
||||
s = graph.GetNegSampledNeighbors({}, 3, meta_info.node_type[1], &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
|
||||
|
||||
neg_neighbors.reset();
|
||||
s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
|
||||
}
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import random
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
|
@ -77,8 +78,110 @@ def test_graphdata_getnodefeature_input_check():
|
|||
g.get_node_feature(input_list, [1, "a"])
|
||||
|
||||
|
||||
def test_graphdata_getsampledneighbors():
|
||||
g = ds.GraphData(DATASET_FILE, 1)
|
||||
edges = g.get_all_edges(0)
|
||||
nodes = g.get_nodes_from_edges(edges)
|
||||
assert len(nodes) == 40
|
||||
neighbor = g.get_sampled_neighbors(
|
||||
np.unique(nodes[0:21, 0]), [2, 3], [2, 1])
|
||||
assert neighbor.shape == (10, 9)
|
||||
|
||||
|
||||
def test_graphdata_getnegsampledneighbors():
|
||||
g = ds.GraphData(DATASET_FILE, 2)
|
||||
nodes = g.get_all_nodes(1)
|
||||
assert len(nodes) == 10
|
||||
neighbor = g.get_neg_sampled_neighbors(nodes, 5, 2)
|
||||
assert neighbor.shape == (10, 6)
|
||||
|
||||
|
||||
def test_graphdata_graphinfo():
|
||||
g = ds.GraphData(DATASET_FILE, 2)
|
||||
graph_info = g.graph_info()
|
||||
assert graph_info['node_type'] == [1, 2]
|
||||
assert graph_info['edge_type'] == [0]
|
||||
assert graph_info['node_num'] == {1: 10, 2: 10}
|
||||
assert graph_info['edge_num'] == {0: 40}
|
||||
assert graph_info['node_feature_type'] == [1, 2, 3, 4]
|
||||
assert graph_info['edge_feature_type'] == []
|
||||
|
||||
|
||||
class RandomBatchedSampler(ds.Sampler):
|
||||
# RandomBatchedSampler generate random sequence without replacement in a batched manner
|
||||
def __init__(self, index_range, num_edges_per_sample):
|
||||
super().__init__()
|
||||
self.index_range = index_range
|
||||
self.num_edges_per_sample = num_edges_per_sample
|
||||
|
||||
def __iter__(self):
|
||||
indices = [i+1 for i in range(self.index_range)]
|
||||
# Reset random seed here if necessary
|
||||
# random.seed(0)
|
||||
random.shuffle(indices)
|
||||
for i in range(0, self.index_range, self.num_edges_per_sample):
|
||||
# Drop reminder
|
||||
if i + self.num_edges_per_sample <= self.index_range:
|
||||
yield indices[i: i + self.num_edges_per_sample]
|
||||
|
||||
|
||||
class GNNGraphDataset():
|
||||
def __init__(self, g, batch_num):
|
||||
self.g = g
|
||||
self.batch_num = batch_num
|
||||
|
||||
def __len__(self):
|
||||
# Total sample size of GNN dataset
|
||||
# In this case, the size should be total_num_edges/num_edges_per_sample
|
||||
return self.g.graph_info()['edge_num'][0] // self.batch_num
|
||||
|
||||
def __getitem__(self, index):
|
||||
# index will be a list of indices yielded from RandomBatchedSampler
|
||||
# Fetch edges/nodes/samples/features based on indices
|
||||
nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
|
||||
nodes = nodes[:, 0]
|
||||
neg_nodes = self.g.get_neg_sampled_neighbors(
|
||||
node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
|
||||
nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
|
||||
2, 2], neighbor_types=[2, 1])
|
||||
neg_nodes_neighbors = self.g.get_sampled_neighbors(
|
||||
node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
|
||||
nodes_neighbors_features = self.g.get_node_feature(
|
||||
node_list=nodes_neighbors, feature_types=[2, 3])
|
||||
neg_neighbors_features = self.g.get_node_feature(
|
||||
node_list=neg_nodes_neighbors, feature_types=[2, 3])
|
||||
return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
|
||||
|
||||
|
||||
def test_graphdata_generatordataset():
|
||||
g = ds.GraphData(DATASET_FILE)
|
||||
batch_num = 2
|
||||
edge_num = g.graph_info()['edge_num'][0]
|
||||
out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
|
||||
dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
|
||||
sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4)
|
||||
dataset = dataset.repeat(2)
|
||||
itr = dataset.create_dict_iterator()
|
||||
i = 0
|
||||
for data in itr:
|
||||
assert data['neighbors'].shape == (2, 7)
|
||||
assert data['neg_neighbors'].shape == (6, 7)
|
||||
assert data['neighbors_features'].shape == (2, 7)
|
||||
assert data['neg_neighbors_features'].shape == (6, 7)
|
||||
i += 1
|
||||
assert i == 40
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_graphdata_getfullneighbor()
|
||||
logger.info('test_graphdata_getfullneighbor Ended.\n')
|
||||
test_graphdata_getnodefeature_input_check()
|
||||
logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
|
||||
test_graphdata_getsampledneighbors()
|
||||
logger.info('test_graphdata_getsampledneighbors Ended.\n')
|
||||
test_graphdata_getnegsampledneighbors()
|
||||
logger.info('test_graphdata_getnegsampledneighbors Ended.\n')
|
||||
test_graphdata_graphinfo()
|
||||
logger.info('test_graphdata_graphinfo Ended.\n')
|
||||
test_graphdata_generatordataset()
|
||||
logger.info('test_graphdata_generatordataset Ended.\n')
|
||||
|
|
Loading…
Reference in New Issue