gnn support weight sampling neighbors.

This commit is contained in:
heleiwang 2021-01-25 10:36:45 +08:00
parent 3708624a25
commit 2dc9ba761c
28 changed files with 282 additions and 76 deletions

View File

@ -65,9 +65,9 @@ PYBIND_REGISTER(
}) })
.def("get_sampled_neighbors", .def("get_sampled_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) { std::vector<gnn::NodeType> neighbor_types, SamplingStrategy strategy) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out));
return out; return out;
}) })
.def("get_neg_sampled_neighbors", .def("get_neg_sampled_neighbors",
@ -114,7 +114,14 @@ PYBIND_REGISTER(
return out; return out;
})) }))
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); }) .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
.def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); }); .def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); });
}));
PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) {
(void)py::enum_<SamplingStrategy>(*m, "SamplingStrategy", py::arithmetic())
.value("DE_SAMPLING_RANDOM", SamplingStrategy::kRandom)
.value("DE_SAMPLING_EDGE_WEIGHT", SamplingStrategy::kEdgeWeight)
.export_values();
})); }));
} // namespace dataset } // namespace dataset

View File

@ -71,6 +71,9 @@ enum class NormalizeForm {
kNfkd, kNfkd,
}; };
// Strategy used when sampling the neighbors of a graph node.
// The numeric values are cast to/from int32 when a sampling request is sent
// between client and server, so they must stay stable.
enum class SamplingStrategy {
  kRandom = 0,      // selects the random neighbor sampler
  kEdgeWeight = 1,  // selects the edge-weight based neighbor sampler
};
// convenience functions for 32bit int bitmask // convenience functions for 32bit int bitmask
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }

View File

@ -35,10 +35,11 @@ class Edge {
// Constructor // Constructor
// @param EdgeIdType id - edge id // @param EdgeIdType id - edge id
// @param EdgeType type - edge type // @param EdgeType type - edge type
// @param WeightType weight - edge weight
// @param std::shared_ptr<Node> src_node - source node // @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node // @param std::shared_ptr<Node> dst_node - destination node
Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node) Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
: id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} : id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {}
virtual ~Edge() = default; virtual ~Edge() = default;
@ -48,6 +49,9 @@ class Edge {
// @return NodeIdType - Returned edge type // @return NodeIdType - Returned edge type
EdgeType type() const { return type_; } EdgeType type() const { return type_; }
// @return WeightType - Returned edge weight
WeightType weight() const { return weight_; }
// Get the feature of a edge // Get the feature of a edge
// @param FeatureType feature_type - type of feature // @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
@ -77,6 +81,7 @@ class Edge {
protected: protected:
EdgeIdType id_; EdgeIdType id_;
EdgeType type_; EdgeType type_;
WeightType weight_;
std::shared_ptr<Node> src_node_; std::shared_ptr<Node> src_node_;
std::shared_ptr<Node> dst_node_; std::shared_ptr<Node> dst_node_;
}; };

View File

@ -71,6 +71,7 @@ message GnnGraphDataRequestPb {
repeated int32 number = 4; // samples number repeated int32 number = 4; // samples number
TensorPb id_tensor = 5; // input ids ,node id or edge id TensorPb id_tensor = 5; // input ids ,node id or edge id
GnnRandomWalkPb random_walk = 6; GnnRandomWalkPb random_walk = 6;
int32 strategy = 7;
} }
message GnnGraphDataResponsePb { message GnnGraphDataResponsePb {

View File

@ -76,11 +76,13 @@ class GraphData {
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param SamplingStrategy strategy - Sampling strategy
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status The status code returned // @return Status The status code returned
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0; const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) = 0;
// Get negative sampled neighbors. // Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes

View File

@ -137,7 +137,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list
Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
GnnGraphDataRequestPb request; GnnGraphDataRequestPb request;
GnnGraphDataResponsePb response; GnnGraphDataResponsePb response;
@ -151,6 +152,7 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_
for (const auto &type : neighbor_types) { for (const auto &type : neighbor_types) {
request.add_type(static_cast<google::protobuf::int32>(type)); request.add_type(static_cast<google::protobuf::int32>(type));
} }
request.set_strategy(static_cast<google::protobuf::int32>(strategy));
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
#endif #endif
return Status::OK(); return Status::OK();

View File

@ -86,10 +86,12 @@ class GraphDataClient : public GraphData {
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param SamplingStrategy strategy - Sampling strategy
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status The status code returned // @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors. // Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes

View File

@ -171,7 +171,8 @@ Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) {
Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
"The sizes of neighbor_nums and neighbor_types are inconsistent."); "The sizes of neighbor_nums and neighbor_types are inconsistent.");
@ -199,7 +200,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
std::shared_ptr<Node> node; std::shared_ptr<Node> node;
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
std::vector<NodeIdType> out; std::vector<NodeIdType> out;
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out));
neighbors.insert(neighbors.end(), out.begin(), out.end()); neighbors.insert(neighbors.end(), out.begin(), out.end());
} }
} }

View File

@ -80,10 +80,12 @@ class GraphDataImpl : public GraphData {
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param SamplingStrategy strategy - Sampling strategy
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status The status code returned // @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override; const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors. // Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes

View File

@ -50,7 +50,7 @@ class GraphDataServer {
enum ServerState state() { return state_; } enum ServerState state() { return state_; }
bool IsStoped() { bool IsStopped() {
if (state_ == kGdsStopped) { if (state_ == kGdsStopped) {
return true; return true;
} else { } else {

View File

@ -78,7 +78,7 @@ grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context,
} }
break; break;
case GraphDataServer::kGdsStopped: case GraphDataServer::kGdsStopped:
response->set_error_msg("Stoped"); response->set_error_msg("Stopped");
break; break;
} }
} else { } else {
@ -222,8 +222,9 @@ Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *re
neighbor_types.resize(request->type().size()); neighbor_types.resize(request->type().size());
std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(),
[](const google::protobuf::int32 type) { return static_cast<NodeType>(type); }); [](const google::protobuf::int32 type) { return static_cast<NodeType>(type); });
SamplingStrategy strategy = static_cast<SamplingStrategy>(request->strategy());
std::shared_ptr<Tensor> tensor; std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &tensor));
TensorPb *result = response->add_result_data(); TensorPb *result = response->add_result_data();
RETURN_IF_NOT_OK(TensorToPb(tensor, result)); RETURN_IF_NOT_OK(TensorToPb(tensor, result));
return Status::OK(); return Status::OK();

View File

@ -39,7 +39,9 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int
row_id_(0), row_id_(0),
shard_reader_(nullptr), shard_reader_(nullptr),
graph_feature_parser_(nullptr), graph_feature_parser_(nullptr),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} required_key_(
{"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}),
optional_key_({{"weight", false}}) {}
Status GraphLoader::GetNodesAndEdges() { Status GraphLoader::GetNodesAndEdges() {
NodeIdMap *n_id_map = &graph_impl_->node_id_map_; NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
@ -62,7 +64,7 @@ Status GraphLoader::GetNodesAndEdges() {
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight()));
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
dq.pop_front(); dq.pop_front();
@ -95,12 +97,18 @@ Status GraphLoader::InitAndLoad() {
graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
mindrecord::json schema = graph_impl_->data_schema_["schema"]; mindrecord::json schema = graph_impl_->data_schema_["schema"];
for (const std::string &key : keys_) { for (const std::string &key : required_key_) {
if (schema.find(key) == schema.end()) { if (schema.find(key) == schema.end()) {
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
} }
} }
for (auto op_key : optional_key_) {
if (schema.find(op_key.first) != schema.end()) {
optional_key_[op_key.first] = true;
}
}
if (graph_impl_->server_mode_) { if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64) #if !defined(_WIN32) && !defined(_WIN64)
int64_t total_blob_size = 0; int64_t total_blob_size = 0;
@ -128,7 +136,11 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
DefaultNodeFeatureMap *default_feature) { DefaultNodeFeatureMap *default_feature) {
NodeIdType node_id = col_jsn["first_id"]; NodeIdType node_id = col_jsn["first_id"];
NodeType node_type = static_cast<NodeType>(col_jsn["type"]); NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
(*node) = std::make_shared<LocalNode>(node_id, node_type); WeightType weight = 1;
if (optional_key_["weight"]) {
weight = col_jsn["weight"];
}
(*node) = std::make_shared<LocalNode>(node_id, node_type, weight);
std::vector<int32_t> indices; std::vector<int32_t> indices;
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) { if (graph_impl_->server_mode_) {
@ -174,9 +186,13 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
EdgeIdType edge_id = col_jsn["first_id"]; EdgeIdType edge_id = col_jsn["first_id"];
EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]); EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1); WeightType edge_weight = 1;
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1); if (optional_key_["weight"]) {
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst); edge_weight = col_jsn["weight"];
}
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1, 1);
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1, 1);
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, edge_weight, src, dst);
std::vector<int32_t> indices; std::vector<int32_t> indices;
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) { if (graph_impl_->server_mode_) {

View File

@ -110,7 +110,8 @@ class GraphLoader {
std::vector<EdgeFeatureMap> e_feature_maps_; std::vector<EdgeFeatureMap> e_feature_maps_;
std::vector<DefaultNodeFeatureMap> default_node_feature_maps_; std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_; std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
const std::vector<std::string> keys_; const std::vector<std::string> required_key_;
std::unordered_map<std::string, bool> optional_key_;
}; };
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset

View File

@ -21,8 +21,9 @@ namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node) LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
: Edge(id, type, src_node, dst_node) {} std::shared_ptr<Node> dst_node)
: Edge(id, type, weight, src_node, dst_node) {}
Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = features_.find(feature_type); auto itr = features_.find(feature_type);

View File

@ -34,9 +34,11 @@ class LocalEdge : public Edge {
// Constructor // Constructor
// @param EdgeIdType id - edge id // @param EdgeIdType id - edge id
// @param EdgeType type - edge type // @param EdgeType type - edge type
// @param WeightType weight - edge weight
// @param std::shared_ptr<Node> src_node - source node // @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node // @param std::shared_ptr<Node> dst_node - destination node
LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node); LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
std::shared_ptr<Node> dst_node);
~LocalEdge() = default; ~LocalEdge() = default;

View File

@ -16,6 +16,7 @@
#include "minddata/dataset/engine/gnn/local_node.h" #include "minddata/dataset/engine/gnn/local_node.h"
#include <algorithm> #include <algorithm>
#include <random>
#include <string> #include <string>
#include <utility> #include <utility>
@ -26,7 +27,10 @@ namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight)
: Node(id, type, weight), rnd_(GetRandomDevice()) {
rnd_.seed(GetSeed());
}
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = features_.find(feature_type); auto itr = features_.find(feature_type);
@ -44,13 +48,13 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
auto itr = neighbor_nodes_.find(neighbor_type); auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) { if (itr != neighbor_nodes_.end()) {
if (exclude_itself) { if (exclude_itself) {
neighbors.resize(itr->second.size()); neighbors.resize(itr->second.first.size());
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(),
[](const std::shared_ptr<Node> node) { return node->id(); }); [](const std::shared_ptr<Node> node) { return node->id(); });
} else { } else {
neighbors.resize(itr->second.size() + 1); neighbors.resize(itr->second.first.size() + 1);
neighbors[0] = id_; neighbors[0] = id_;
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1,
[](const std::shared_ptr<Node> node) { return node->id(); }); [](const std::shared_ptr<Node> node) { return node->id(); });
} }
} else { } else {
@ -63,7 +67,7 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
return Status::OK(); return Status::OK();
} }
Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
std::vector<NodeIdType> *out) { std::vector<NodeIdType> *out) {
std::vector<NodeIdType> shuffled_id(neighbors.size()); std::vector<NodeIdType> shuffled_id(neighbors.size());
std::iota(shuffled_id.begin(), shuffled_id.end(), 0); std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
@ -75,14 +79,33 @@ Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &
return Status::OK(); return Status::OK();
} }
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
const std::vector<WeightType> &weights, int32_t samples_num,
std::vector<NodeIdType> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(),
"The number of neighbors does not match the weight.");
std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end());
for (int32_t i = 0; i < samples_num; ++i) {
NodeIdType index = discrete_dist(rnd_);
out->emplace_back(neighbors[index]->id());
}
return Status::OK();
}
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
std::vector<NodeIdType> *out_neighbors) { std::vector<NodeIdType> *out_neighbors) {
std::vector<NodeIdType> neighbors; std::vector<NodeIdType> neighbors;
neighbors.reserve(samples_num); neighbors.reserve(samples_num);
auto itr = neighbor_nodes_.find(neighbor_type); auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) { if (itr != neighbor_nodes_.end()) {
if (strategy == SamplingStrategy::kRandom) {
while (neighbors.size() < samples_num) { while (neighbors.size() < samples_num) {
RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors));
}
} else if (strategy == SamplingStrategy::kEdgeWeight) {
RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors));
} else {
RETURN_STATUS_UNEXPECTED("Invalid strategy");
} }
} else { } else {
MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
@ -95,12 +118,15 @@ Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_nu
return Status::OK(); return Status::OK();
} }
Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node) { Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) {
auto itr = neighbor_nodes_.find(node->type()); auto itr = neighbor_nodes_.find(node->type());
if (itr != neighbor_nodes_.end()) { if (itr != neighbor_nodes_.end()) {
itr->second.push_back(node); itr->second.first.push_back(node);
itr->second.second.push_back(weight);
} else { } else {
neighbor_nodes_[node->type()] = {node}; std::vector<std::shared_ptr<Node>> nodes = {node};
std::vector<WeightType> weights = {weight};
neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights));
} }
return Status::OK(); return Status::OK();
} }

View File

@ -18,6 +18,7 @@
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <utility>
#include <vector> #include <vector>
#include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/engine/gnn/node.h"
@ -33,7 +34,7 @@ class LocalNode : public Node {
// Constructor // Constructor
// @param NodeIdType id - node id // @param NodeIdType id - node id
// @param NodeType type - node type // @param NodeType type - node type
LocalNode(NodeIdType id, NodeType type); LocalNode(NodeIdType id, NodeType type, WeightType weight);
~LocalNode() = default; ~LocalNode() = default;
@ -53,15 +54,16 @@ class LocalNode : public Node {
// Get the sampled neighbors of a node // Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired // @param int32_t samples_num - Number of neighbors to be acquired
// @param SamplingStrategy strategy - Sampling strategy
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status The status code returned // @return Status The status code returned
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
std::vector<NodeIdType> *out_neighbors) override; std::vector<NodeIdType> *out_neighbors) override;
// Add neighbor of node // Add neighbor of node
// @param std::shared_ptr<Node> node - // @param std::shared_ptr<Node> node -
// @return Status The status code returned // @return Status The status code returned
Status AddNeighbor(const std::shared_ptr<Node> &node) override; Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override;
// Update feature of node // Update feature of node
// @param std::shared_ptr<Feature> feature - // @param std::shared_ptr<Feature> feature -
@ -69,12 +71,16 @@ class LocalNode : public Node {
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override; Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
private: private:
Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num, Status GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
std::vector<NodeIdType> *out);
Status GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
const std::vector<WeightType> &weights, int32_t samples_num,
std::vector<NodeIdType> *out); std::vector<NodeIdType> *out);
std::mt19937 rnd_; std::mt19937 rnd_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_; std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_; std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
}; };
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset

View File

@ -28,6 +28,7 @@ namespace dataset {
namespace gnn { namespace gnn {
using NodeType = int8_t; using NodeType = int8_t;
using NodeIdType = int32_t; using NodeIdType = int32_t;
using WeightType = float;
constexpr NodeIdType kDefaultNodeId = -1; constexpr NodeIdType kDefaultNodeId = -1;
@ -36,7 +37,8 @@ class Node {
// Constructor // Constructor
// @param NodeIdType id - node id // @param NodeIdType id - node id
// @param NodeType type - node type // @param NodeType type - node type
Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} // @param WeightType type - node weight
Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {}
virtual ~Node() = default; virtual ~Node() = default;
@ -46,6 +48,9 @@ class Node {
// @return NodeIdType - Returned node type // @return NodeIdType - Returned node type
NodeType type() const { return type_; } NodeType type() const { return type_; }
// @return WeightType - Returned node weight
WeightType weight() const { return weight_; }
// Get the feature of a node // Get the feature of a node
// @param FeatureType feature_type - type of feature // @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature // @param std::shared_ptr<Feature> *out_feature - Returned feature
@ -62,15 +67,16 @@ class Node {
// Get the sampled neighbors of a node // Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor // @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired // @param int32_t samples_num - Number of neighbors to be acquired
// @param SamplingStrategy strategy - Sampling strategy
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id // @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status The status code returned // @return Status The status code returned
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
std::vector<NodeIdType> *out_neighbors) = 0; std::vector<NodeIdType> *out_neighbors) = 0;
// Add neighbor of node // Add neighbor of node
// @param std::shared_ptr<Node> node - // @param std::shared_ptr<Node> node -
// @return Status The status code returned // @return Status The status code returned
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0; virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0;
// Update feature of node // Update feature of node
// @param std::shared_ptr<Feature> feature - // @param std::shared_ptr<Feature> feature -
@ -80,6 +86,7 @@ class Node {
protected: protected:
NodeIdType id_; NodeIdType id_;
NodeType type_; NodeType type_;
WeightType weight_;
}; };
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset

View File

@ -71,6 +71,9 @@ enum class NormalizeForm {
kNfkd, kNfkd,
}; };
// Possible values for SamplingStrategy
enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 };
// convenience functions for 32bit int bitmask // convenience functions for 32bit int bitmask
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }

View File

@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
from ..core import config from ..core import config
from .cache_client import DatasetCache from .cache_client import DatasetCache
from .datasets import * from .datasets import *
from .graphdata import GraphData from .graphdata import GraphData, SamplingStrategy
from .iterators import * from .iterators import *
from .samplers import * from .samplers import *
from .serializer_deserializer import compare, deserialize, serialize, show from .serializer_deserializer import compare, deserialize, serialize, show

View File

@ -18,10 +18,12 @@ and provides operations related to graph data.
""" """
import atexit import atexit
import time import time
from enum import IntEnum
import numpy as np import numpy as np
from mindspore._c_dataengine import GraphDataClient from mindspore._c_dataengine import GraphDataClient
from mindspore._c_dataengine import GraphDataServer from mindspore._c_dataengine import GraphDataServer
from mindspore._c_dataengine import Tensor from mindspore._c_dataengine import Tensor
from mindspore._c_dataengine import SamplingStrategy as Sampling
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
@ -29,6 +31,17 @@ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_
check_gnn_random_walk check_gnn_random_walk
class SamplingStrategy(IntEnum):
    """
    Strategy used by GraphData.get_sampled_neighbors to pick the neighbors of a node.

    - RANDOM: random sampling with replacement.
    - EDGE_WEIGHT: sampling with edge weight as probability.
    """
    # Random sampling with replacement.
    RANDOM = 0
    # Sampling with edge weight as probability.
    EDGE_WEIGHT = 1
# Translates each Python-level SamplingStrategy member into the corresponding SamplingStrategy value
# exported by mindspore._c_dataengine (imported in this module as `Sampling`).
DE_C_INTER_SAMPLING_STRATEGY = {
    SamplingStrategy.RANDOM: Sampling.DE_SAMPLING_RANDOM,
    SamplingStrategy.EDGE_WEIGHT: Sampling.DE_SAMPLING_EDGE_WEIGHT,
}
class GraphData: class GraphData:
""" """
Reads the graph dataset used for GNN training from the shared file and database. Reads the graph dataset used for GNN training from the shared file and database.
@ -86,7 +99,7 @@ class GraphData:
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
atexit.register(stop) atexit.register(stop)
try: try:
while self._graph_data.is_stoped() is not True: while self._graph_data.is_stopped() is not True:
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
raise Exception("Graph data server receives KeyboardInterrupt.") raise Exception("Graph data server receives KeyboardInterrupt.")
@ -185,7 +198,7 @@ class GraphData:
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
@check_gnn_get_sampled_neighbors @check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM):
""" """
Get sampled neighbor information. Get sampled neighbor information.
@ -199,6 +212,11 @@ class GraphData:
node_list (Union[list, numpy.ndarray]): The given list of nodes. node_list (Union[list, numpy.ndarray]): The given list of nodes.
neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop. neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop.
neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.
strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM).
It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT].
- SamplingStrategy.RANDOM, random sampling with replacement.
- SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability.
Returns: Returns:
numpy.ndarray, array of neighbors. numpy.ndarray, array of neighbors.
@ -215,10 +233,12 @@ class GraphData:
TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray.
""" """
if not isinstance(strategy, SamplingStrategy):
raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.")
if self._working_mode == 'server': if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server.") raise Exception("This method is not supported when working mode is server.")
return self._graph_data.get_sampled_neighbors( return self._graph_data.get_sampled_neighbors(
node_list, neighbor_nums, neighbor_types).as_array() node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array()
@check_gnn_get_neg_sampled_neighbors @check_gnn_get_neg_sampled_neighbors
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):

View File

@ -1114,7 +1114,7 @@ def check_gnn_get_sampled_neighbors(method):
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
[node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs) [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs)
check_gnn_list_or_ndarray(node_list, 'node_list') check_gnn_list_or_ndarray(node_list, 'node_list')

View File

@ -37,6 +37,7 @@ class GraphMapSchema:
"second_id": {"type": "int64"}, "second_id": {"type": "int64"},
"third_id": {"type": "int64"}, "third_id": {"type": "int64"},
"type": {"type": "int32"}, "type": {"type": "int32"},
"weight": {"type": "float32"},
"attribute": {"type": "string"}, # 'n' for node, 'e' for edge "attribute": {"type": "string"}, # 'n' for node, 'e' for edge
"node_feature_index": {"type": "int32", "shape": [-1]}, "node_feature_index": {"type": "int32", "shape": [-1]},
"edge_feature_index": {"type": "int32", "shape": [-1]} "edge_feature_index": {"type": "int32", "shape": [-1]}
@ -91,8 +92,11 @@ class GraphMapSchema:
logger.info("node cannot be None.") logger.info("node cannot be None.")
raise ValueError("node cannot be None.") raise ValueError("node cannot be None.")
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n',
"node_feature_index": []} "type": node["type"], "node_feature_index": []}
if "weight" in node:
node_graph["weight"] = node["weight"]
for i in range(self.num_node_features): for i in range(self.num_node_features):
k = i + 1 k = i + 1
node_field_key = 'feature_' + str(k) node_field_key = 'feature_' + str(k)
@ -129,8 +133,11 @@ class GraphMapSchema:
logger.info("edge cannot be None.") logger.info("edge cannot be None.")
raise ValueError("edge cannot be None.") raise ValueError("edge cannot be None.")
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0,
"type": edge["type"], "edge_feature_index": []} "attribute": 'e', "type": edge["type"], "edge_feature_index": []}
if "weight" in edge:
edge_graph["weight"] = edge["weight"]
for i in range(self.num_edge_features): for i in range(self.num_edge_features):
k = i + 1 k = i + 1

View File

@ -15,6 +15,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <map>
#include <memory> #include <memory>
#include <unordered_set> #include <unordered_set>
@ -38,6 +39,60 @@ using namespace mindspore::dataset::gnn;
class MindDataTestGNNGraph : public UT::Common { class MindDataTestGNNGraph : public UT::Common {
protected: protected:
MindDataTestGNNGraph() = default; MindDataTestGNNGraph() = default;
using NumNeighborsMap = std::map<NodeIdType, uint32_t>;
using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>;
void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) {
auto shape_vec = neighbors->shape().AsVector();
uint32_t num_members = 1;
for (size_t i = 1; i < shape_vec.size(); ++i) {
num_members *= shape_vec[i];
}
uint32_t index = 0;
NodeIdType src_node = 0;
for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>();
++node_itr, ++index) {
if (index % num_members == 0) {
src_node = *node_itr;
continue;
}
auto src_node_itr = node_neighbors.find(src_node);
if (src_node_itr == node_neighbors.end()) {
node_neighbors[src_node] = {{*node_itr, 1}};
} else {
auto nei_itr = src_node_itr->second.find(*node_itr);
if (nei_itr == src_node_itr->second.end()) {
src_node_itr->second[*node_itr] = 1;
} else {
src_node_itr->second[*node_itr] += 1;
}
}
}
}
void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights,
float deviation_ratio = 0.1) {
EXPECT_EQ(number_neighbors.size(), weights.size());
int index = 0;
uint32_t pre_num = 0;
WeightType pre_weight = 1;
for (auto neighbor : number_neighbors) {
if (pre_num != 0) {
float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]);
float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second);
float target_upper = target_ratio * (1 + deviation_ratio);
float target_lower = target_ratio * (1 - deviation_ratio);
MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio)
<< " target_upper:" << std::to_string(target_upper)
<< " target_lower:" << std::to_string(target_lower);
EXPECT_LE(current_ratio, target_upper);
EXPECT_GE(current_ratio, target_lower);
}
pre_num = neighbor.second;
pre_weight = weights[index];
++index;
}
}
}; };
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
@ -131,44 +186,75 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
std::shared_ptr<Tensor> neighbors; std::shared_ptr<Tensor> neighbors;
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); {
MS_LOG(INFO) << "Test random sampling.";
NodeNeighborsMap number_neighbors;
int count = 0;
while (count < 1000) {
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
ParsingNeighbors(neighbors, number_neighbors);
++count;
}
CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1});
}
{
MS_LOG(INFO) << "Test edge weight sampling.";
NodeNeighborsMap number_neighbors;
int count = 0;
while (count < 1000) {
neighbors.reset();
s =
graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
ParsingNeighbors(neighbors, number_neighbors);
++count;
}
CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8});
}
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors); s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors); s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
std::string::npos); std::string::npos);
neighbors.reset(); neighbors.reset();
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
} }

View File

@ -17,6 +17,7 @@ import pytest
import numpy as np import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
DATASET_FILE = "../data/mindrecord/testGraphData/testdata" DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
@ -97,7 +98,10 @@ def test_graphdata_getsampledneighbors():
nodes = g.get_nodes_from_edges(edges) nodes = g.get_nodes_from_edges(edges)
assert len(nodes) == 40 assert len(nodes) == 40
neighbor = g.get_sampled_neighbors( neighbor = g.get_sampled_neighbors(
np.unique(nodes[0:21, 0]), [2, 3], [2, 1]) np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM)
assert neighbor.shape == (10, 9)
neighbor = g.get_sampled_neighbors(
np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT)
assert neighbor.shape == (10, 9) assert neighbor.shape == (10, 9)

View File

@ -20,6 +20,7 @@ from multiprocessing import Process
import numpy as np import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore import log as logger from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
DATASET_FILE = "../data/mindrecord/testGraphData/testdata" DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
@ -68,9 +69,9 @@ class GNNGraphDataset():
neg_nodes = self.g.get_neg_sampled_neighbors( neg_nodes = self.g.get_neg_sampled_neighbors(
node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
2, 2], neighbor_types=[2, 1]) 2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM)
neg_nodes_neighbors = self.g.get_sampled_neighbors( neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
nodes_neighbors_features = self.g.get_node_feature( nodes_neighbors_features = self.g.get_node_feature(
node_list=nodes_neighbors, feature_types=[2, 3]) node_list=nodes_neighbors, feature_types=[2, 3])
neg_neighbors_features = self.g.get_node_feature( neg_neighbors_features = self.g.get_node_feature(