Add get edges from nodes function for gnn network

This commit is contained in:
Zhenglong Li 2021-04-09 15:52:07 +08:00
parent 5556b12da4
commit f3895f2b09
20 changed files with 263 additions and 13 deletions

View File

@ -57,6 +57,12 @@ PYBIND_REGISTER(
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out;
})
.def("get_edges_from_nodes",
[](gnn::GraphData &g, std::vector<std::pair<gnn::NodeIdType, gnn::NodeIdType>> node_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetEdgesFromNodes(node_list, &out));
return out;
})
.def("get_all_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
std::shared_ptr<Tensor> out;

View File

@ -50,12 +50,13 @@ enum GnnOpName {
GET_ALL_NODES = 0;
GET_ALL_EDGES = 1;
GET_NODES_FROM_EDGES = 2;
GET_ALL_NEIGHBORS = 3;
GET_SAMPLED_NEIGHBORS = 4;
GET_NEG_SAMPLED_NEIGHBORS = 5;
RANDOM_WALK = 6;
GET_NODE_FEATURE = 7;
GET_EDGE_FEATURE = 8;
GET_EDGES_FROM_NODES = 3;
GET_ALL_NEIGHBORS = 4;
GET_SAMPLED_NEIGHBORS = 5;
GET_NEG_SAMPLED_NEIGHBORS = 6;
RANDOM_WALK = 7;
GET_NODE_FEATURE = 8;
GET_EDGE_FEATURE = 9;
}
message GnnRandomWalkPb {
@ -64,6 +65,11 @@ message GnnRandomWalkPb {
int32 default_id = 3;
}
// A pair of node ids, carried in GnnGraphDataRequestPb.node_pair for the
// GET_EDGES_FROM_NODES operation.
message IdPairPb {
int32 src_id = 1;  // id of the source node (first element of the pair)
int32 dst_id = 2;  // id of the destination node (second element of the pair)
}
message GnnGraphDataRequestPb {
GnnOpName op_name = 1;
repeated int32 id = 2; // node id or edge id
@ -72,6 +78,7 @@ message GnnGraphDataRequestPb {
TensorPb id_tensor = 5; // input ids ,node id or edge id
GnnRandomWalkPb random_walk = 6;
int32 strategy = 7;
repeated IdPairPb node_pair = 8;
}
message GnnGraphDataResponsePb {

View File

@ -62,6 +62,13 @@ class GraphData {
// @return Status The status code returned
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
// Get the id of the edge connecting each given pair of nodes
// @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of (source node id, destination node id)
//     pairs
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The status code that indicates the result of function execution
virtual Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
std::shared_ptr<Tensor> *out) = 0;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported

View File

@ -120,6 +120,25 @@ Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_li
return Status::OK();
}
// Ask the graph data server for the edge ids that connect the given node pairs.
// @param node_list - List of (source node id, destination node id) pairs
// @param out - Receives the tensor produced by GetGraphDataTensor
// @return Status - Error from GetGraphDataTensor, otherwise OK. On Windows builds the request code is compiled out
//     and OK is returned without touching `out`.
Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
                                          std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataRequestPb rpc_request;
  rpc_request.set_op_name(GET_EDGES_FROM_NODES);
  // One IdPairPb entry per input pair, in input order.
  for (auto pair_itr = node_list.cbegin(); pair_itr != node_list.cend(); ++pair_itr) {
    IdPairPb *pb_pair = rpc_request.add_node_pair();
    pb_pair->set_src_id(static_cast<google::protobuf::int32>(pair_itr->first));
    pb_pair->set_dst_id(static_cast<google::protobuf::int32>(pair_itr->second));
  }
  GnnGraphDataResponsePb rpc_response;
  RETURN_IF_NOT_OK(GetGraphDataTensor(rpc_request, &rpc_response, out));
#endif
  return Status::OK();
}
Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)

View File

@ -72,6 +72,13 @@ class GraphDataClient : public GraphData {
// @return Status The status code returned
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// Get the id of the edge connecting each given pair of nodes by sending a GET_EDGES_FROM_NODES request
// @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of (source node id, destination node id)
//     pairs
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The status code that indicates the result of function execution
Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported

View File

@ -128,6 +128,30 @@ Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list
return Status::OK();
}
// Look up the id of the edge that connects each (source, destination) node pair.
// @param node_list - List of (source node id, destination node id) pairs; must not be empty
// @param out - Receives an int32 tensor built from one edge id per input pair, in input order
// @return Status - Error if node_list is empty, a source node cannot be found, the adjacency lookup fails or
//     yields no edge id, or the tensor cannot be created; otherwise OK
Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
                                        std::shared_ptr<Tensor> *out) {
  if (node_list.empty()) {
    RETURN_STATUS_UNEXPECTED("Input node list is empty.");
  }

  std::vector<std::vector<EdgeIdType>> edge_list;
  edge_list.reserve(node_list.size());

  for (const auto &node_id : node_list) {
    std::shared_ptr<Node> src_node;
    RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node));

    EdgeIdType *edge_id = nullptr;
    // Propagate a failed lookup instead of silently continuing with an unset pointer.
    RETURN_IF_NOT_OK(src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id));
    if (edge_id == nullptr) {
      // Guard the dereference below: a Node implementation may return OK without setting the output.
      RETURN_STATUS_UNEXPECTED("Failed to get the edge id of the input node pair.");
    }

    // The pointee is only read here; each pair contributes a single-element row.
    edge_list.emplace_back(std::vector<EdgeIdType>{*edge_id});
  }

  RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>(edge_list, DataType(DataType::DE_INT32), out));
  return Status::OK();
}
Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");

View File

@ -66,6 +66,13 @@ class GraphDataImpl : public GraphData {
// @return Status The status code returned
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// Get the id of the edge connecting each given pair of nodes
// @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of (source node id, destination node id)
//     pairs; an empty list is reported as an error
// @param std::shared_ptr<Tensor> *out - Returned edge ids, one per input pair
// @return Status - The status code that indicates the result of function execution
Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported

View File

@ -17,6 +17,7 @@
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/gnn/tensor_proto.h"
@ -31,6 +32,7 @@ static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = {
{GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes},
{GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges},
{GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges},
{GET_EDGES_FROM_NODES, &GraphDataServiceImpl::GetEdgesFromNodes},
{GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors},
{GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors},
{GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors},
@ -189,6 +191,27 @@ Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *requ
return Status::OK();
}
// Serve a GET_EDGES_FROM_NODES request: decode the node pairs, query the graph and encode the resulting tensor.
// @param request - Request whose node_pair field holds the (src_id, dst_id) pairs; must contain at least one pair
// @param response - Response that receives the edge id tensor as a new result_data entry
// @return Status - Error if no pair was supplied, the graph lookup fails or the tensor cannot be serialized
Status GraphDataServiceImpl::GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->node_pair_size() > 0, "The input node pair id list is empty.");

  // Convert the protobuf pairs into the in-memory representation, keeping the request order.
  std::vector<std::pair<NodeIdType, NodeIdType>> node_pairs;
  node_pairs.reserve(request->node_pair().size());
  for (const auto &pb_pair : request->node_pair()) {
    node_pairs.emplace_back(static_cast<NodeIdType>(pb_pair.src_id()), static_cast<NodeIdType>(pb_pair.dst_id()));
  }

  std::shared_ptr<Tensor> edge_tensor;
  RETURN_IF_NOT_OK(graph_data_impl_->GetEdgesFromNodes(node_pairs, &edge_tensor));

  // The result slot is only added once the lookup has succeeded.
  TensorPb *pb_result = response->add_result_data();
  RETURN_IF_NOT_OK(TensorToPb(edge_tensor, pb_result));
  return Status::OK();
}
Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");

View File

@ -50,6 +50,7 @@ class GraphDataServiceImpl {
Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);

View File

@ -61,10 +61,14 @@ Status GraphLoader::GetNodesAndEdges() {
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight()));
RETURN_IF_NOT_OK(src_itr->second->AddAdjacent(dst_itr->second, edge_ptr));
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
dq.pop_front();

View File

@ -131,6 +131,26 @@ Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightTyp
return Status::OK();
}
// Record which edge connects this node to the given adjacent node.
// @param node - The adjacent node; its id is used as the key of the adjacency table
// @param edge - The connecting edge; its id is stored as the value
// @return Status - Always OK
Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) {
  // Like insert(), emplace() keeps an already stored entry for the same adjacent node id.
  adjacent_nodes_.emplace(node->id(), edge->id());
  return Status::OK();
}
// Find the id of the edge that connects this node to the given adjacent node.
// @param adj_node_id - The id of the adjacent node to look up
// @param out_edge_id - Receives a pointer to the connecting edge id. The pointer is never owned by the caller: it
//     refers either to the entry stored in this node's adjacency table or, when the node is not adjacent, to a
//     shared value of -1. Callers must treat the pointee as read-only and must not delete it.
// @return Status - Always OK; a missing adjacency is reported through the -1 value and a warning log
Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) {
  // Shared "no connecting edge" value. Handing out its address replaces a heap allocation on every miss that no
  // caller could safely free, because it is indistinguishable from a pointer into adjacent_nodes_.
  static EdgeIdType no_edge_id = -1;

  auto itr = adjacent_nodes_.find(adj_node_id);
  if (itr != adjacent_nodes_.end()) {
    (*out_edge_id) = &(itr->second);
  } else {
    (*out_edge_id) = &no_edge_id;
    MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node.";
  }
  return Status::OK();
}
Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) {
auto itr = features_.find(feature->type());
if (itr != features_.end()) {

View File

@ -65,6 +65,18 @@ class LocalNode : public Node {
// @return Status The status code returned
Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override;
// Add an adjacent node and the edge connecting to it into this node's adjacency table
// @param std::shared_ptr<Node> node - the adjacent node to be inserted into the adjacency table
// @param std::shared_ptr<Edge> edge - the edge that connects this node to the adjacent node
// @return Status - The status code that indicates the result of function execution
Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) override;
// Get the id of the edge connecting this node to an adjacent node, looked up by the adjacent node's id
// @param NodeIdType adj_node_id - The id of the adjacent node to be processed
// @param EdgeIdType **out_edge_id - Receives a pointer to the id of the connecting edge. If the node is not
//     adjacent, a warning is logged and the pointed-to id is -1
// @return Status - The status code that indicates the result of function execution
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override;
// Update feature of node
// @param std::shared_ptr<Feature> feature -
// @return Status The status code returned
@ -81,6 +93,7 @@ class LocalNode : public Node {
std::mt19937 rnd_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
std::unordered_map<NodeIdType, EdgeIdType> adjacent_nodes_;
};
} // namespace gnn
} // namespace dataset

View File

@ -29,9 +29,12 @@ namespace gnn {
using NodeType = int8_t;
using NodeIdType = int32_t;
using WeightType = float;
// Integer type used for edge ids.
using EdgeIdType = int32_t;
constexpr NodeIdType kDefaultNodeId = -1;
// Forward declaration so that Node's interface can take std::shared_ptr<Edge> parameters.
class Edge;
class Node {
public:
// Constructor
@ -78,6 +81,18 @@ class Node {
// @return Status The status code returned
virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0;
// Add an adjacent node and the edge connecting to it for this (source) node
// @param std::shared_ptr<Node> node - the adjacent node to be inserted into the adjacency table
// @param std::shared_ptr<Edge> edge - the edge that connects this node to the adjacent node
// @return Status - The status code that indicates the result of function execution
virtual Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) = 0;
// Get the id of the edge connecting this node to an adjacent node, looked up by the adjacent node's id
// @param NodeIdType adj_node_id - The id of the adjacent node to be processed
// @param EdgeIdType **out_edge_id - Receives a pointer to the id of the connecting edge
// @return Status - The status code that indicates the result of function execution
virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0;
// Update feature of node
// @param std::shared_ptr<Feature> feature -
// @return Status The status code returned

View File

@ -102,8 +102,7 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share
unsigned char *ret_ptr = data.get();
std::shared_ptr<DvppDataInfo> DecodeOut(process.Get_Decode_DeviceData());
dsize_t dvpp_length = DecodeOut->dataSize;
// dsize_t decode_height = DecodeOut->height;
// dsize_t decode_width = DecodeOut->width;
const TensorShape dvpp_shape({dvpp_length, 1, 1});
const DataType dvpp_data_type(DataType::DE_UINT8);
mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output);

View File

@ -360,6 +360,38 @@ def validate_dataset_param_value(param_list, param_dict, param_type):
type_check(param_dict.get(param_name), (param_type,), param_name)
def check_gnn_list_of_pair_or_ndarray(param, param_name):
    """
    Check that the input parameter is a list of integer pairs or an equivalent numpy.ndarray.

    A list must hold tuples of exactly two ints. An ndarray must have two dimensions, a second
    dimension of size 2 and dtype int32.

    Args:
        param (Union[list[tuple], numpy.ndarray]): The value to be checked.
        param_name (str): Name of the parameter, used in the error messages.

    Raises:
        TypeError: If the ndarray dtype is not int32. Type mismatches found by `type_check` and
            `type_check_list` are raised by those helpers.
        ValueError: If a list member does not have length 2, or the ndarray is not two-dimensional
            with a second dimension of size 2.
    """
    type_check(param, (list, np.ndarray), param_name)

    if isinstance(param, list):
        pair_labels = ["pair_{0}".format(pos) for pos in range(len(param))]
        type_check_list(param, (tuple,), pair_labels)
        for label, pair in zip(pair_labels, param):
            pair_len = len(pair)
            if pair_len != 2:
                raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
                    label, pair_len))
            element_labels = ["element_{0}".format(pos) for pos in range(pair_len)]
            type_check_list(pair, (int,), element_labels)
    elif isinstance(param, np.ndarray):
        if param.ndim != 2:
            raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim))
        pair_len = param.shape[1]
        if pair_len != 2:
            raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
                param_name, pair_len))
        if param.dtype != np.int32:
            raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
                param_name, param.dtype))
def check_gnn_list_or_ndarray(param, param_name):
"""
Check if the input parameter is list or numpy.ndarray.

View File

@ -26,9 +26,9 @@ from mindspore._c_dataengine import Tensor
from mindspore._c_dataengine import SamplingStrategy as Sampling
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \
check_gnn_random_walk
check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \
check_gnn_get_sampled_neighbors, check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, \
check_gnn_get_edge_feature, check_gnn_random_walk
class SamplingStrategy(IntEnum):
@ -162,6 +162,27 @@ class GraphData:
raise Exception("This method is not supported when working mode is server.")
return self._graph_data.get_nodes_from_edges(edge_list).as_array()
@check_gnn_get_edges_from_nodes
def get_edges_from_nodes(self, node_list):
    """
    Get the edges that connect the given pairs of nodes.

    Args:
        node_list (Union[list[tuple], numpy.ndarray]): The given list of node ID pairs.

    Returns:
        numpy.ndarray, array of edge IDs.

    Examples:
        >>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)])

    Raises:
        TypeError: If `node_list` is not list or ndarray.
        Exception: If the working mode is 'server'.
    """
    in_server_mode = self._working_mode == 'server'
    if in_server_mode:
        raise Exception("This method is not supported when working mode is server.")
    edge_tensor = self._graph_data.get_edges_from_nodes(node_list)
    return edge_tensor.as_array()
@check_gnn_get_all_neighbors
def get_all_neighbors(self, node_list, neighbor_type):
"""

View File

@ -25,8 +25,8 @@ import numpy as np
from mindspore._c_expression import typing
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
check_columns, check_pos_int32, check_valid_str
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str
from . import datasets
from . import samplers
@ -1090,6 +1090,19 @@ def check_gnn_get_nodes_from_edges(method):
return new_method
def check_gnn_get_edges_from_nodes(method):
    """
    Decorator that validates the `node_list` argument of the GNN `get_edges_from_nodes` function
    before calling it.

    Args:
        method (Callable): The `get_edges_from_nodes` method to wrap.

    Returns:
        Callable, the wrapped method, which runs the check and then calls `method` with the
        original arguments.
    """

    @wraps(method)
    def checked_method(self, *args, **kwargs):
        # Exactly one user argument is expected; the unpacking fails otherwise.
        (node_list,), _ = parse_user_args(method, *args, **kwargs)
        check_gnn_list_of_pair_or_ndarray(node_list, "node_list")
        return method(self, *args, **kwargs)

    return checked_method
def check_gnn_get_all_neighbors(method):
"""A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""

View File

@ -95,6 +95,21 @@ class MindDataTestGNNGraph : public UT::Common {
}
};
// Load the test graph and check that GetEdgesFromNodes returns the expected edge id for each node pair.
TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
  std::string dataset_path = "data/mindrecord/testGraphData/testdata";
  GraphDataImpl graph(dataset_path, 1);
  Status rc = graph.Init();
  EXPECT_TRUE(rc.IsOk());

  // (source node id, destination node id) pairs to look up.
  std::vector<std::pair<NodeIdType, NodeIdType>> node_pairs = {{101, 201}, {103, 207}, {108, 208},
                                                               {110, 201}, {204, 105}, {208, 108}};
  std::shared_ptr<Tensor> edge_ids;
  rc = graph.GetEdgesFromNodes(node_pairs, &edge_ids);
  EXPECT_TRUE(rc.IsOk());

  const std::string expected_dump = "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]";
  EXPECT_TRUE(edge_ids->ToString() == expected_dump);
}
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
GraphDataImpl graph(path, 1);

View File

@ -241,6 +241,18 @@ def test_graphdata_getedgefeature():
assert features[1].shape == (40,)
def test_graphdata_getedgesfromnodes():
    """
    Test get_edges_from_nodes: each (source, destination) node pair maps to the expected edge id
    """
    logger.info('test get_edges_from_nodes\n')
    graph = ds.GraphData(DATASET_FILE)
    node_pairs = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
    expected_edge_ids = [1, 9, 31, 17, 20, 40]
    edge_ids = graph.get_edges_from_nodes(node_pairs)
    assert edge_ids.tolist() == expected_edge_ids
if __name__ == '__main__':
test_graphdata_getfullneighbor()
test_graphdata_getnodefeature_input_check()
@ -251,3 +263,4 @@ if __name__ == '__main__':
test_graphdata_randomwalkdefault()
test_graphdata_randomwalk()
test_graphdata_getedgefeature()
test_graphdata_getedgesfromnodes()

View File

@ -112,6 +112,10 @@ def test_graphdata_distributed():
assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
edges = g.get_edges_from_nodes(nodes_pair_list)
assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]
batch_num = 2
edge_num = g.graph_info()['edge_num'][0]
out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]