forked from mindspore-Ecosystem/mindspore
Add get edges from nodes function for gnn network
This commit is contained in:
parent
5556b12da4
commit
f3895f2b09
|
@ -57,6 +57,12 @@ PYBIND_REGISTER(
|
|||
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_edges_from_nodes",
|
||||
[](gnn::GraphData &g, std::vector<std::pair<gnn::NodeIdType, gnn::NodeIdType>> node_list) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetEdgesFromNodes(node_list, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_neighbors",
|
||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
|
|
|
@ -50,12 +50,13 @@ enum GnnOpName {
|
|||
GET_ALL_NODES = 0;
|
||||
GET_ALL_EDGES = 1;
|
||||
GET_NODES_FROM_EDGES = 2;
|
||||
GET_ALL_NEIGHBORS = 3;
|
||||
GET_SAMPLED_NEIGHBORS = 4;
|
||||
GET_NEG_SAMPLED_NEIGHBORS = 5;
|
||||
RANDOM_WALK = 6;
|
||||
GET_NODE_FEATURE = 7;
|
||||
GET_EDGE_FEATURE = 8;
|
||||
GET_EDGES_FROM_NODES = 3;
|
||||
GET_ALL_NEIGHBORS = 4;
|
||||
GET_SAMPLED_NEIGHBORS = 5;
|
||||
GET_NEG_SAMPLED_NEIGHBORS = 6;
|
||||
RANDOM_WALK = 7;
|
||||
GET_NODE_FEATURE = 8;
|
||||
GET_EDGE_FEATURE = 9;
|
||||
}
|
||||
|
||||
message GnnRandomWalkPb {
|
||||
|
@ -64,6 +65,11 @@ message GnnRandomWalkPb {
|
|||
int32 default_id = 3;
|
||||
}
|
||||
|
||||
message IdPairPb {
|
||||
int32 src_id = 1;
|
||||
int32 dst_id = 2;
|
||||
}
|
||||
|
||||
message GnnGraphDataRequestPb {
|
||||
GnnOpName op_name = 1;
|
||||
repeated int32 id = 2; // node id or edge id
|
||||
|
@ -72,6 +78,7 @@ message GnnGraphDataRequestPb {
|
|||
TensorPb id_tensor = 5; // input ids ,node id or edge id
|
||||
GnnRandomWalkPb random_walk = 6;
|
||||
int32 strategy = 7;
|
||||
repeated IdPairPb node_pair = 8;
|
||||
}
|
||||
|
||||
message GnnGraphDataResponsePb {
|
||||
|
|
|
@ -62,6 +62,13 @@ class GraphData {
|
|||
// @return Status The status code returned
|
||||
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get the edge id from connected node pair
|
||||
// @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
virtual Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
|
|
|
@ -120,6 +120,25 @@ Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_li
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
GnnGraphDataRequestPb request;
|
||||
GnnGraphDataResponsePb response;
|
||||
|
||||
request.set_op_name(GET_EDGES_FROM_NODES);
|
||||
|
||||
for (const auto &pair_node_id : node_list) {
|
||||
IdPairPb *proto_pair(request.add_node_pair());
|
||||
proto_pair->set_src_id(static_cast<google::protobuf::int32>(pair_node_id.first));
|
||||
proto_pair->set_dst_id(static_cast<google::protobuf::int32>(pair_node_id.second));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
|
|
|
@ -72,6 +72,13 @@ class GraphDataClient : public GraphData {
|
|||
// @return Status The status code returned
|
||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get the edge id from connected node pair
|
||||
// @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
|
|
|
@ -128,6 +128,30 @@ Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
if (node_list.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Input node list is empty.");
|
||||
}
|
||||
|
||||
std::vector<std::vector<EdgeIdType>> edge_list;
|
||||
edge_list.reserve(node_list.size());
|
||||
|
||||
for (const auto &node_id : node_list) {
|
||||
std::shared_ptr<Node> src_node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node));
|
||||
|
||||
EdgeIdType *edge_id = nullptr;
|
||||
src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id);
|
||||
|
||||
std::vector<EdgeIdType> connection_edge = {*edge_id};
|
||||
edge_list.emplace_back(std::move(connection_edge));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<EdgeIdType>(edge_list, DataType(DataType::DE_INT32), out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
|
|
|
@ -66,6 +66,13 @@ class GraphDataImpl : public GraphData {
|
|||
// @return Status The status code returned
|
||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get the edge id from connected node pair
|
||||
// @param std::vector<std::pair<NodeIdType, NodeIdType>> node_list - List of pair nodes
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
Status GetEdgesFromNodes(const std::vector<std::pair<NodeIdType, NodeIdType>> &node_list,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/tensor_proto.h"
|
||||
|
@ -31,6 +32,7 @@ static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = {
|
|||
{GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes},
|
||||
{GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges},
|
||||
{GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges},
|
||||
{GET_EDGES_FROM_NODES, &GraphDataServiceImpl::GetEdgesFromNodes},
|
||||
{GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors},
|
||||
{GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors},
|
||||
{GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors},
|
||||
|
@ -189,6 +191,27 @@ Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *requ
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataServiceImpl::GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(request->node_pair_size() > 0, "The input node pair id list is empty.");
|
||||
|
||||
std::vector<std::pair<NodeIdType, NodeIdType>> node_list;
|
||||
node_list.resize(request->node_pair().size());
|
||||
|
||||
std::transform(
|
||||
request->node_pair().begin(), request->node_pair().end(), node_list.begin(), [](const auto &node_pair_id) {
|
||||
auto cur_pair =
|
||||
std::make_pair(static_cast<NodeIdType>(node_pair_id.src_id()), static_cast<NodeIdType>(node_pair_id.dst_id()));
|
||||
return cur_pair;
|
||||
});
|
||||
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(graph_data_impl_->GetEdgesFromNodes(node_list, &tensor));
|
||||
|
||||
TensorPb *result = response->add_result_data();
|
||||
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
|
||||
|
|
|
@ -50,6 +50,7 @@ class GraphDataServiceImpl {
|
|||
Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
Status GetEdgesFromNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||
|
|
|
@ -61,10 +61,14 @@ Status GraphLoader::GetNodesAndEdges() {
|
|||
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
|
||||
RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
|
||||
auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
|
||||
|
||||
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight()));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddAdjacent(dst_itr->second, edge_ptr));
|
||||
|
||||
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
|
||||
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
|
||||
dq.pop_front();
|
||||
|
|
|
@ -131,6 +131,26 @@ Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightTyp
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) {
|
||||
auto node_id = node->id();
|
||||
auto edge_id = edge->id();
|
||||
adjacent_nodes_.insert({node_id, edge_id});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) {
|
||||
auto itr = adjacent_nodes_.find(adj_node_id);
|
||||
|
||||
if (itr != adjacent_nodes_.end()) {
|
||||
(*out_edge_id) = &(itr->second);
|
||||
} else {
|
||||
(*out_edge_id) = new EdgeIdType(-1);
|
||||
MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node.";
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) {
|
||||
auto itr = features_.find(feature->type());
|
||||
if (itr != features_.end()) {
|
||||
|
|
|
@ -65,6 +65,18 @@ class LocalNode : public Node {
|
|||
// @return Status The status code returned
|
||||
Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override;
|
||||
|
||||
// Add adjacent node and relative edge for source node
|
||||
// @param std::shared_ptr<Node> node - the node to be inserted into adjacent table
|
||||
// @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) override;
|
||||
|
||||
// Get relative connecting edge of adjacent node by node id
|
||||
// @param NodeIdType - The id of adjacent node to be processed
|
||||
// @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status The status code returned
|
||||
|
@ -81,6 +93,7 @@ class LocalNode : public Node {
|
|||
std::mt19937 rnd_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
|
||||
std::unordered_map<NodeIdType, EdgeIdType> adjacent_nodes_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
|
|
|
@ -29,9 +29,12 @@ namespace gnn {
|
|||
using NodeType = int8_t;
|
||||
using NodeIdType = int32_t;
|
||||
using WeightType = float;
|
||||
using EdgeIdType = int32_t;
|
||||
|
||||
constexpr NodeIdType kDefaultNodeId = -1;
|
||||
|
||||
class Edge;
|
||||
|
||||
class Node {
|
||||
public:
|
||||
// Constructor
|
||||
|
@ -78,6 +81,18 @@ class Node {
|
|||
// @return Status The status code returned
|
||||
virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0;
|
||||
|
||||
// Add adjacent node and relative edge for source node
|
||||
// @param std::shared_ptr<Node> node - the node to be inserted into adjacent table
|
||||
// @param std::shared_ptr<Edge> edge - the edge related to the adjacent node of source node
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
virtual Status AddAdjacent(const std::shared_ptr<Node> &node, const std::shared_ptr<Edge> &edge) = 0;
|
||||
|
||||
// Get relative connecting edge of adjacent node by node id
|
||||
// @param NodeIdType - The id of adjacent node to be processed
|
||||
// @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status The status code returned
|
||||
|
|
|
@ -102,8 +102,7 @@ Status DvppDecodePngOp::Compute(const std::shared_ptr<Tensor> &input, std::share
|
|||
unsigned char *ret_ptr = data.get();
|
||||
std::shared_ptr<DvppDataInfo> DecodeOut(process.Get_Decode_DeviceData());
|
||||
dsize_t dvpp_length = DecodeOut->dataSize;
|
||||
// dsize_t decode_height = DecodeOut->height;
|
||||
// dsize_t decode_width = DecodeOut->width;
|
||||
|
||||
const TensorShape dvpp_shape({dvpp_length, 1, 1});
|
||||
const DataType dvpp_data_type(DataType::DE_UINT8);
|
||||
mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output);
|
||||
|
|
|
@ -360,6 +360,38 @@ def validate_dataset_param_value(param_list, param_dict, param_type):
|
|||
type_check(param_dict.get(param_name), (param_type,), param_name)
|
||||
|
||||
|
||||
def check_gnn_list_of_pair_or_ndarray(param, param_name):
|
||||
"""
|
||||
Check if the input parameter is a list of tuple or numpy.ndarray.
|
||||
|
||||
Args:
|
||||
param (Union[list[tuple], nd.ndarray]): param.
|
||||
param_name (str): param_name.
|
||||
|
||||
Returns:
|
||||
Exception: TypeError if error.
|
||||
"""
|
||||
type_check(param, (list, np.ndarray), param_name)
|
||||
if isinstance(param, list):
|
||||
param_names = ["pair_{0}".format(i) for i in range(len(param))]
|
||||
type_check_list(param, (tuple,), param_names)
|
||||
for idx, pair in enumerate(param):
|
||||
if not len(pair) == 2:
|
||||
raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
|
||||
param_names[idx], len(pair)))
|
||||
column_names = ["element_{0}".format(i) for i in range(len(pair))]
|
||||
type_check_list(pair, (int,), column_names)
|
||||
elif isinstance(param, np.ndarray):
|
||||
if param.ndim != 2:
|
||||
raise ValueError("Input ndarray must be in dimension 2. Got {0}".format(param.ndim))
|
||||
if param.shape[1] != 2:
|
||||
raise ValueError("Each member in {0} must be a pair which means length == 2. Got length {1}".format(
|
||||
param_name, param.shape[1]))
|
||||
if not param.dtype == np.int32:
|
||||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||
param_name, param.dtype))
|
||||
|
||||
|
||||
def check_gnn_list_or_ndarray(param, param_name):
|
||||
"""
|
||||
Check if the input parameter is list or numpy.ndarray.
|
||||
|
|
|
@ -26,9 +26,9 @@ from mindspore._c_dataengine import Tensor
|
|||
from mindspore._c_dataengine import SamplingStrategy as Sampling
|
||||
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
|
||||
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_get_edge_feature, \
|
||||
check_gnn_random_walk
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \
|
||||
check_gnn_get_sampled_neighbors, check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, \
|
||||
check_gnn_get_edge_feature, check_gnn_random_walk
|
||||
|
||||
|
||||
class SamplingStrategy(IntEnum):
|
||||
|
@ -162,6 +162,27 @@ class GraphData:
|
|||
raise Exception("This method is not supported when working mode is server.")
|
||||
return self._graph_data.get_nodes_from_edges(edge_list).as_array()
|
||||
|
||||
@check_gnn_get_edges_from_nodes
|
||||
def get_edges_from_nodes(self, node_list):
|
||||
"""
|
||||
Get edges from the nodes.
|
||||
|
||||
Args:
|
||||
node_list (Union[list[tuple], numpy.ndarray]): The given list of pair nodes ID.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, array of edgs ID.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)])
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
"""
|
||||
if self._working_mode == 'server':
|
||||
raise Exception("This method is not supported when working mode is server.")
|
||||
return self._graph_data.get_edges_from_nodes(node_list).as_array()
|
||||
|
||||
@check_gnn_get_all_neighbors
|
||||
def get_all_neighbors(self, node_list, neighbor_type):
|
||||
"""
|
||||
|
|
|
@ -25,8 +25,8 @@ import numpy as np
|
|||
from mindspore._c_expression import typing
|
||||
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
|
||||
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
||||
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_num_parallel_workers, \
|
||||
check_columns, check_pos_int32, check_valid_str
|
||||
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
|
||||
check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str
|
||||
|
||||
from . import datasets
|
||||
from . import samplers
|
||||
|
@ -1090,6 +1090,19 @@ def check_gnn_get_nodes_from_edges(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_edges_from_nodes(method):
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_edges_from_nodes` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[node_list], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_gnn_list_of_pair_or_ndarray(node_list, "node_list")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_all_neighbors(method):
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_all_neighbors` function."""
|
||||
|
||||
|
|
|
@ -95,6 +95,21 @@ class MindDataTestGNNGraph : public UT::Common {
|
|||
}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
std::vector<std::pair<NodeIdType, NodeIdType>> src_dst_list = {{101, 201}, {103, 207}, {108, 208},
|
||||
{110, 201}, {204, 105}, {208, 108}};
|
||||
std::shared_ptr<Tensor> edges;
|
||||
s = graph.GetEdgesFromNodes(src_dst_list, &edges);
|
||||
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(edges->ToString() == "Tensor (shape: <6>, Type: int32)\n[1,9,17,19,31,37]");
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
|
|
|
@ -241,6 +241,18 @@ def test_graphdata_getedgefeature():
|
|||
assert features[1].shape == (40,)
|
||||
|
||||
|
||||
def test_graphdata_getedgesfromnodes():
|
||||
"""
|
||||
Test get edges from nodes
|
||||
"""
|
||||
logger.info('test get_edges_from_nodes\n')
|
||||
g = ds.GraphData(DATASET_FILE)
|
||||
|
||||
nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
|
||||
edges = g.get_edges_from_nodes(nodes_pair_list)
|
||||
assert edges.tolist() == [1, 9, 31, 17, 20, 40]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_graphdata_getfullneighbor()
|
||||
test_graphdata_getnodefeature_input_check()
|
||||
|
@ -251,3 +263,4 @@ if __name__ == '__main__':
|
|||
test_graphdata_randomwalkdefault()
|
||||
test_graphdata_randomwalk()
|
||||
test_graphdata_getedgefeature()
|
||||
test_graphdata_getedgesfromnodes()
|
||||
|
|
|
@ -112,6 +112,10 @@ def test_graphdata_distributed():
|
|||
assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
|
||||
|
||||
nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
|
||||
edges = g.get_edges_from_nodes(nodes_pair_list)
|
||||
assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]
|
||||
|
||||
batch_num = 2
|
||||
edge_num = g.graph_info()['edge_num'][0]
|
||||
out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
|
||||
|
|
Loading…
Reference in New Issue