From 2dc9ba761c575c720b8713ca312fa143664ce029 Mon Sep 17 00:00:00 2001 From: heleiwang Date: Mon, 25 Jan 2021 10:36:45 +0800 Subject: [PATCH] gnn support weight sampling neighbors. --- .../bindings/dataset/engine/gnn/bindings.cc | 13 ++- .../ccsrc/minddata/dataset/core/constants.h | 3 + .../ccsrc/minddata/dataset/engine/gnn/edge.h | 9 +- .../dataset/engine/gnn/gnn_graph_data.proto | 1 + .../minddata/dataset/engine/gnn/graph_data.h | 6 +- .../dataset/engine/gnn/graph_data_client.cc | 4 +- .../dataset/engine/gnn/graph_data_client.h | 6 +- .../dataset/engine/gnn/graph_data_impl.cc | 5 +- .../dataset/engine/gnn/graph_data_impl.h | 8 +- .../dataset/engine/gnn/graph_data_server.h | 2 +- .../engine/gnn/graph_data_service_impl.cc | 5 +- .../dataset/engine/gnn/graph_loader.cc | 30 +++-- .../dataset/engine/gnn/graph_loader.h | 3 +- .../minddata/dataset/engine/gnn/local_edge.cc | 5 +- .../minddata/dataset/engine/gnn/local_edge.h | 4 +- .../minddata/dataset/engine/gnn/local_node.cc | 52 ++++++--- .../minddata/dataset/engine/gnn/local_node.h | 18 ++- .../ccsrc/minddata/dataset/engine/gnn/node.h | 13 ++- .../minddata/dataset/include/constants.h | 3 + mindspore/dataset/engine/__init__.py | 2 +- mindspore/dataset/engine/graphdata.py | 28 ++++- mindspore/dataset/engine/validators.py | 2 +- .../graph_to_mindrecord/graph_map_schema.py | 15 ++- tests/ut/cpp/dataset/gnn_graph_test.cc | 108 ++++++++++++++++-- .../ut/data/mindrecord/testGraphData/testdata | Bin 52682 -> 70026 bytes .../data/mindrecord/testGraphData/testdata.db | Bin 16384 -> 16384 bytes tests/ut/python/dataset/test_graphdata.py | 6 +- .../dataset/test_graphdata_distributed.py | 7 +- 28 files changed, 282 insertions(+), 76 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc index 936ba2804e9..5e8c5e3a1c2 100644 --- 
a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc @@ -65,9 +65,9 @@ PYBIND_REGISTER( }) .def("get_sampled_neighbors", [](gnn::GraphData &g, std::vector node_list, std::vector neighbor_nums, - std::vector neighbor_types) { + std::vector neighbor_types, SamplingStrategy strategy) { std::shared_ptr out; - THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); + THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out)); return out; }) .def("get_neg_sampled_neighbors", @@ -114,8 +114,15 @@ PYBIND_REGISTER( return out; })) .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); }) - .def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); }); + .def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); }); })); +PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) { + (void)py::enum_(*m, "SamplingStrategy", py::arithmetic()) + .value("DE_SAMPLING_RANDOM", SamplingStrategy::kRandom) + .value("DE_SAMPLING_EDGE_WEIGHT", SamplingStrategy::kEdgeWeight) + .export_values(); + })); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index 480eb682f0e..d749827c880 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -71,6 +71,9 @@ enum class NormalizeForm { kNfkd, }; +// Possible values for SamplingStrategy +enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 }; + // convenience functions for 32bit int bitmask inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h index b11c20bf3a5..b80e73de3f4 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h @@ -35,10 +35,11 @@ class Edge { // Constructor // @param EdgeIdType id - edge id // @param EdgeType type - edge type + // @param WeightType weight - edge weight // @param std::shared_ptr src_node - source node // @param std::shared_ptr dst_node - destination node - Edge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} + Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr src_node, std::shared_ptr dst_node) + : id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {} virtual ~Edge() = default; @@ -48,6 +49,9 @@ class Edge { // @return NodeIdType - Returned edge type EdgeType type() const { return type_; } + // @return WeightType - Returned edge weight + WeightType weight() const { return weight_; } + // Get the feature of a edge // @param FeatureType feature_type - type of feature // @param std::shared_ptr *out_feature - Returned feature @@ -77,6 +81,7 @@ class Edge { protected: EdgeIdType id_; EdgeType type_; + WeightType weight_; std::shared_ptr src_node_; std::shared_ptr dst_node_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto index 1342d047cc5..f95a823eb7d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto @@ -71,6 +71,7 @@ message GnnGraphDataRequestPb { repeated int32 number = 4; // samples number TensorPb id_tensor = 5; // input ids ,node id or edge id GnnRandomWalkPb random_walk = 6; + int32 strategy = 7; } message GnnGraphDataResponsePb { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h index a790c363f38..c50bf194dd5 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h @@ -76,11 +76,13 @@ class GraphData { // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param SamplingStrategy strategy - Sampling strategy // @param std::shared_ptr *out - Returned neighbor's id. // @return Status The status code returned virtual Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) = 0; + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) = 0; // Get negative sampled neighbors. // @param std::vector node_list - List of nodes @@ -95,7 +97,7 @@ class GraphData { // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param float step_away_param - in out hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc index 6fdde154268..a9f618ccbe9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc @@ -137,7 +137,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector &node_list Status GraphDataClient::GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { + const std::vector 
&neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) { #if !defined(_WIN32) && !defined(_WIN64) GnnGraphDataRequestPb request; GnnGraphDataResponsePb response; @@ -151,6 +152,7 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector &node_ for (const auto &type : neighbor_types) { request.add_type(static_cast(type)); } + request.set_strategy(static_cast(strategy)); RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); #endif return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h index 36cb9169a42..0e8d08f11bf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h @@ -86,10 +86,12 @@ class GraphDataClient : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param SamplingStrategy strategy - Sampling strategy // @param std::shared_ptr *out - Returned neighbor's id. // @return Status The status code returned Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) override; + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) override; // Get negative sampled neighbors. 
// @param std::vector node_list - List of nodes @@ -104,7 +106,7 @@ class GraphDataClient : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param float step_away_param - in out hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc index b97cad0ffaa..70a090bb97b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc @@ -171,7 +171,8 @@ Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) { Status GraphDataImpl::GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), "The sizes of neighbor_nums and neighbor_types are inconsistent."); @@ -199,7 +200,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector &node_li std::shared_ptr node; RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); std::vector out; - RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); + RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out)); neighbors.insert(neighbors.end(), out.begin(), out.end()); } } diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h index 6b6fc05d475..b5db50768b5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h @@ -80,10 +80,12 @@ class GraphDataImpl : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param SamplingStrategy strategy - Sampling strategy // @param std::shared_ptr *out - Returned neighbor's id. // @return Status The status code returned Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) override; + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) override; // Get negative sampled neighbors. // @param std::vector node_list - List of nodes @@ -98,7 +100,7 @@ class GraphDataImpl : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param float step_away_param - in out hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status The status code returned @@ -194,7 +196,7 @@ class GraphDataImpl : public GraphData { std::vector node_list_; std::vector meta_path_; float step_home_param_; // Return hyper parameter. Default is 1.0 - float step_away_param_; // Inout hyper parameter. Default is 1.0 + float step_away_param_; // In out hyper parameter. 
Default is 1.0 NodeIdType default_node_; int32_t num_walks_; // Number of walks per source. Default is 1 diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h index 2bf1ad2b51e..49c57a1c3e1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h @@ -50,7 +50,7 @@ class GraphDataServer { enum ServerState state() { return state_; } - bool IsStoped() { + bool IsStopped() { if (state_ == kGdsStopped) { return true; } else { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc index 3fe4c96ce81..04b930bb557 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc @@ -78,7 +78,7 @@ grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context, } break; case GraphDataServer::kGdsStopped: - response->set_error_msg("Stoped"); + response->set_error_msg("Stopped"); break; } } else { @@ -222,8 +222,9 @@ Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *re neighbor_types.resize(request->type().size()); std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), [](const google::protobuf::int32 type) { return static_cast(type); }); + SamplingStrategy strategy = static_cast(request->strategy()); std::shared_ptr tensor; - RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); + RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &tensor)); TensorPb *result = response->add_result_data(); RETURN_IF_NOT_OK(TensorToPb(tensor, result)); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc 
b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 1d043ced745..16dcfa4d3a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -39,7 +39,9 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int row_id_(0), shard_reader_(nullptr), graph_feature_parser_(nullptr), - keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} + required_key_( + {"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}), + optional_key_({{"weight", false}}) {} Status GraphLoader::GetNodesAndEdges() { NodeIdMap *n_id_map = &graph_impl_->node_id_map_; @@ -62,7 +64,7 @@ Status GraphLoader::GetNodesAndEdges() { CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); - RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); + RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); dq.pop_front(); @@ -95,12 +97,18 @@ Status GraphLoader::InitAndLoad() { graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); mindrecord::json schema = graph_impl_->data_schema_["schema"]; - for (const std::string &key : keys_) { + for (const std::string &key : required_key_) { if (schema.find(key) == schema.end()) { RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); } } + for (auto op_key : optional_key_) { + if (schema.find(op_key.first) != schema.end()) { + optional_key_[op_key.first] = true; + } + } + 
if (graph_impl_->server_mode_) { #if !defined(_WIN32) && !defined(_WIN64) int64_t total_blob_size = 0; @@ -128,7 +136,11 @@ Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrec DefaultNodeFeatureMap *default_feature) { NodeIdType node_id = col_jsn["first_id"]; NodeType node_type = static_cast(col_jsn["type"]); - (*node) = std::make_shared(node_id, node_type); + WeightType weight = 1; + if (optional_key_["weight"]) { + weight = col_jsn["weight"]; + } + (*node) = std::make_shared(node_id, node_type, weight); std::vector indices; RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); if (graph_impl_->server_mode_) { @@ -174,9 +186,13 @@ Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrec EdgeIdType edge_id = col_jsn["first_id"]; EdgeType edge_type = static_cast(col_jsn["type"]); NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; - std::shared_ptr src = std::make_shared(src_id, -1); - std::shared_ptr dst = std::make_shared(dst_id, -1); - (*edge) = std::make_shared(edge_id, edge_type, src, dst); + WeightType edge_weight = 1; + if (optional_key_["weight"]) { + edge_weight = col_jsn["weight"]; + } + std::shared_ptr src = std::make_shared(src_id, -1, 1); + std::shared_ptr dst = std::make_shared(dst_id, -1, 1); + (*edge) = std::make_shared(edge_id, edge_type, edge_weight, src, dst); std::vector indices; RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); if (graph_impl_->server_mode_) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h index 7b397547fb5..f3a7fd0edff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h @@ -110,7 +110,8 @@ class GraphLoader { std::vector e_feature_maps_; std::vector default_node_feature_maps_; std::vector 
default_edge_feature_maps_; - const std::vector keys_; + const std::vector required_key_; + std::unordered_map optional_key_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc index d20be6e318e..27eb508ff6d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc @@ -21,8 +21,9 @@ namespace mindspore { namespace dataset { namespace gnn { -LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : Edge(id, type, src_node, dst_node) {} +LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr src_node, + std::shared_ptr dst_node) + : Edge(id, type, weight, src_node, dst_node) {} Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = features_.find(feature_type); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h index 9c365723b71..fa860fb9297 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h @@ -34,9 +34,11 @@ class LocalEdge : public Edge { // Constructor // @param EdgeIdType id - edge id // @param EdgeType type - edge type + // @param WeightType weight - edge weight // @param std::shared_ptr src_node - source node // @param std::shared_ptr dst_node - destination node - LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); + LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr src_node, + std::shared_ptr dst_node); ~LocalEdge() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc index 8eaf9bb7163..bd7114f571a 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc @@ -16,6 +16,7 @@ #include "minddata/dataset/engine/gnn/local_node.h" #include +#include #include #include @@ -26,7 +27,10 @@ namespace mindspore { namespace dataset { namespace gnn { -LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } +LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight) + : Node(id, type, weight), rnd_(GetRandomDevice()) { + rnd_.seed(GetSeed()); +} Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = features_.find(feature_type); @@ -44,13 +48,13 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vectorsecond.size()); - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), + neighbors.resize(itr->second.first.size()); + std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(), [](const std::shared_ptr node) { return node->id(); }); } else { - neighbors.resize(itr->second.size() + 1); + neighbors.resize(itr->second.first.size() + 1); neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1, [](const std::shared_ptr node) { return node->id(); }); } } else { @@ -63,8 +67,8 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector> &neighbors, int32_t samples_num, - std::vector *out) { +Status LocalNode::GetRandomSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out) { std::vector shuffled_id(neighbors.size()); std::iota(shuffled_id.begin(), shuffled_id.end(), 0); std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); @@ -75,14 +79,33 @@ Status LocalNode::GetSampledNeighbors(const std::vector> & return Status::OK(); } -Status 
LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, +Status LocalNode::GetWeightSampledNeighbors(const std::vector> &neighbors, + const std::vector &weights, int32_t samples_num, + std::vector *out) { + CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(), + "The number of neighbors does not match the weight."); + std::discrete_distribution discrete_dist(weights.begin(), weights.end()); + for (int32_t i = 0; i < samples_num; ++i) { + NodeIdType index = discrete_dist(rnd_); + out->emplace_back(neighbors[index]->id()); + } + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, std::vector *out_neighbors) { std::vector neighbors; neighbors.reserve(samples_num); auto itr = neighbor_nodes_.find(neighbor_type); if (itr != neighbor_nodes_.end()) { - while (neighbors.size() < samples_num) { - RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); + if (strategy == SamplingStrategy::kRandom) { + while (neighbors.size() < samples_num) { + RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors)); + } + } else if (strategy == SamplingStrategy::kEdgeWeight) { + RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors)); + } else { + RETURN_STATUS_UNEXPECTED("Invalid strategy"); } } else { MS_LOG(DEBUG) << "There are no neighbors. 
node_id:" << id_ << " neighbor_type:" << neighbor_type; @@ -95,12 +118,15 @@ Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_nu return Status::OK(); } -Status LocalNode::AddNeighbor(const std::shared_ptr &node) { +Status LocalNode::AddNeighbor(const std::shared_ptr &node, const WeightType &weight) { auto itr = neighbor_nodes_.find(node->type()); if (itr != neighbor_nodes_.end()) { - itr->second.push_back(node); + itr->second.first.push_back(node); + itr->second.second.push_back(weight); } else { - neighbor_nodes_[node->type()] = {node}; + std::vector> nodes = {node}; + std::vector weights = {weight}; + neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights)); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h index 75ba021ee28..7ec030556a8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -18,6 +18,7 @@ #include #include +#include #include #include "minddata/dataset/engine/gnn/node.h" @@ -33,7 +34,7 @@ class LocalNode : public Node { // Constructor // @param NodeIdType id - node id // @param NodeType type - node type - LocalNode(NodeIdType id, NodeType type); + LocalNode(NodeIdType id, NodeType type, WeightType weight); ~LocalNode() = default; @@ -53,15 +54,16 @@ class LocalNode : public Node { // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor // @param int32_t samples_num - Number of neighbors to be acquired + // @param SamplingStrategy strategy - Sampling strategy // @param std::vector *out_neighbors - Returned neighbors id // @return Status The status code returned - Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, std::vector *out_neighbors) override; // 
Add neighbor of node // @param std::shared_ptr node - // @return Status The status code returned - Status AddNeighbor(const std::shared_ptr &node) override; + Status AddNeighbor(const std::shared_ptr &node, const WeightType &) override; // Update feature of node // @param std::shared_ptr feature - @@ -69,12 +71,16 @@ class LocalNode : public Node { Status UpdateFeature(const std::shared_ptr &feature) override; private: - Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, - std::vector *out); + Status GetRandomSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out); + + Status GetWeightSampledNeighbors(const std::vector> &neighbors, + const std::vector &weights, int32_t samples_num, + std::vector *out); std::mt19937 rnd_; std::unordered_map> features_; - std::unordered_map>> neighbor_nodes_; + std::unordered_map>, std::vector>> neighbor_nodes_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h index 1f5aaf55a8b..3382df5c244 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -28,6 +28,7 @@ namespace dataset { namespace gnn { using NodeType = int8_t; using NodeIdType = int32_t; +using WeightType = float; constexpr NodeIdType kDefaultNodeId = -1; @@ -36,7 +37,8 @@ class Node { // Constructor // @param NodeIdType id - node id // @param NodeType type - node type - Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} + // @param WeightType weight - node weight + Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {} virtual ~Node() = default; @@ -46,6 +48,9 @@ class Node { // @return NodeIdType - Returned node type NodeType type() const { return type_; } + // @return WeightType - Returned node weight + WeightType weight() const { return weight_; } + // Get the feature of a node // 
@param FeatureType feature_type - type of feature // @param std::shared_ptr *out_feature - Returned feature @@ -62,15 +67,16 @@ class Node { // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor // @param int32_t samples_num - Number of neighbors to be acquired + // @param SamplingStrategy strategy - Sampling strategy // @param std::vector *out_neighbors - Returned neighbors id // @return Status The status code returned - virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, std::vector *out_neighbors) = 0; // Add neighbor of node // @param std::shared_ptr node - // @return Status The status code returned - virtual Status AddNeighbor(const std::shared_ptr &node) = 0; + virtual Status AddNeighbor(const std::shared_ptr &node, const WeightType &weight) = 0; // Update feature of node // @param std::shared_ptr feature - @@ -80,6 +86,7 @@ class Node { protected: NodeIdType id_; NodeType type_; + WeightType weight_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/include/constants.h b/mindspore/ccsrc/minddata/dataset/include/constants.h index e5474847f4a..2f34ea573dc 100644 --- a/mindspore/ccsrc/minddata/dataset/include/constants.h +++ b/mindspore/ccsrc/minddata/dataset/include/constants.h @@ -71,6 +71,9 @@ enum class NormalizeForm { kNfkd, }; +// Possible values for SamplingStrategy +enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 }; + // convenience functions for 32bit int bitmask inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 7609d2266ec..b2ab05150eb 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, 
batch, repeat, map, and zip. from ..core import config from .cache_client import DatasetCache from .datasets import * -from .graphdata import GraphData +from .graphdata import GraphData, SamplingStrategy from .iterators import * from .samplers import * from .serializer_deserializer import compare, deserialize, serialize, show diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 58fe40f47d9..68afda3f3be 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -18,10 +18,12 @@ and provides operations related to graph data. """ import atexit import time +from enum import IntEnum import numpy as np from mindspore._c_dataengine import GraphDataClient from mindspore._c_dataengine import GraphDataServer from mindspore._c_dataengine import Tensor +from mindspore._c_dataengine import SamplingStrategy as Sampling from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ @@ -29,6 +31,17 @@ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_ check_gnn_random_walk +class SamplingStrategy(IntEnum): + RANDOM = 0 + EDGE_WEIGHT = 1 + + +DE_C_INTER_SAMPLING_STRATEGY = { + SamplingStrategy.RANDOM: Sampling.DE_SAMPLING_RANDOM, + SamplingStrategy.EDGE_WEIGHT: Sampling.DE_SAMPLING_EDGE_WEIGHT, +} + + class GraphData: """ Reads the graph dataset used for GNN training from the shared file and database. 
@@ -86,7 +99,7 @@ class GraphData: dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) atexit.register(stop) try: - while self._graph_data.is_stoped() is not True: + while self._graph_data.is_stopped() is not True: time.sleep(1) except KeyboardInterrupt: raise Exception("Graph data server receives KeyboardInterrupt.") @@ -185,7 +198,7 @@ class GraphData: return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() @check_gnn_get_sampled_neighbors - def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): + def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM): """ Get sampled neighbor information. @@ -199,6 +212,11 @@ class GraphData: node_list (Union[list, numpy.ndarray]): The given list of nodes. neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop. neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. + strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM). + It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]. + + - SamplingStrategy.RANDOM, random sampling with replacement. + - SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability. Returns: numpy.ndarray, array of neighbors. @@ -215,10 +233,12 @@ class GraphData: TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray. 
""" + if not isinstance(strategy, SamplingStrategy): + raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.") if self._working_mode == 'server': raise Exception("This method is not supported when working mode is server.") return self._graph_data.get_sampled_neighbors( - node_list, neighbor_nums, neighbor_types).as_array() + node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array() @check_gnn_get_neg_sampled_neighbors def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): @@ -342,7 +362,7 @@ class GraphData: target_nodes (list[int]): Start node list in random walk meta_path (list[int]): node type for each walk step step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0). - step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0). + step_away_param (float, optional): in out hyper parameter in node2vec algorithm (Default = 1.0). default_node (int, optional): default node if no more neighbors found (Default = -1). A default value of -1 indicates that no node is given. 
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 724207d5aea..39af3cbb384 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1114,7 +1114,7 @@ def check_gnn_get_sampled_neighbors(method): @wraps(method) def new_method(self, *args, **kwargs): - [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs) + [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs) check_gnn_list_or_ndarray(node_list, 'node_list') diff --git a/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py b/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py index 1da1ced2f7d..1ee1fa343ff 100644 --- a/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py +++ b/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py @@ -37,6 +37,7 @@ class GraphMapSchema: "second_id": {"type": "int64"}, "third_id": {"type": "int64"}, "type": {"type": "int32"}, + "weight": {"type": "float32"}, "attribute": {"type": "string"}, # 'n' for ndoe, 'e' for edge "node_feature_index": {"type": "int32", "shape": [-1]}, "edge_feature_index": {"type": "int32", "shape": [-1]} @@ -91,8 +92,11 @@ class GraphMapSchema: logger.info("node cannot be None.") raise ValueError("node cannot be None.") - node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], - "node_feature_index": []} + node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n', + "type": node["type"], "node_feature_index": []} + if "weight" in node: + node_graph["weight"] = node["weight"] + for i in range(self.num_node_features): k = i + 1 node_field_key = 'feature_' + str(k) @@ -129,8 +133,11 @@ class GraphMapSchema: logger.info("edge cannot be None.") raise ValueError("edge cannot be None.") - edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', 
- "type": edge["type"], "edge_feature_index": []} + edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0, + "attribute": 'e', "type": edge["type"], "edge_feature_index": []} + + if "weight" in edge: + edge_graph["weight"] = edge["weight"] for i in range(self.num_edge_features): k = i + 1 diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 66ad0a6e857..521ffff6342 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -15,6 +15,7 @@ */ #include #include +#include #include #include @@ -38,6 +39,60 @@ using namespace mindspore::dataset::gnn; class MindDataTestGNNGraph : public UT::Common { protected: MindDataTestGNNGraph() = default; + + using NumNeighborsMap = std::map; + using NodeNeighborsMap = std::map; + void ParsingNeighbors(const std::shared_ptr &neighbors, NodeNeighborsMap &node_neighbors) { + auto shape_vec = neighbors->shape().AsVector(); + uint32_t num_members = 1; + for (size_t i = 1; i < shape_vec.size(); ++i) { + num_members *= shape_vec[i]; + } + uint32_t index = 0; + NodeIdType src_node = 0; + for (auto node_itr = neighbors->begin(); node_itr != neighbors->end(); + ++node_itr, ++index) { + if (index % num_members == 0) { + src_node = *node_itr; + continue; + } + auto src_node_itr = node_neighbors.find(src_node); + if (src_node_itr == node_neighbors.end()) { + node_neighbors[src_node] = {{*node_itr, 1}}; + } else { + auto nei_itr = src_node_itr->second.find(*node_itr); + if (nei_itr == src_node_itr->second.end()) { + src_node_itr->second[*node_itr] = 1; + } else { + src_node_itr->second[*node_itr] += 1; + } + } + } + } + + void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector &weights, + float deviation_ratio = 0.1) { + EXPECT_EQ(number_neighbors.size(), weights.size()); + int index = 0; + uint32_t pre_num = 0; + WeightType pre_weight = 1; + for (auto neighbor : 
number_neighbors) { + if (pre_num != 0) { + float target_ratio = static_cast(pre_weight) / static_cast(weights[index]); + float current_ratio = static_cast(pre_num) / static_cast(neighbor.second); + float target_upper = target_ratio * (1 + deviation_ratio); + float target_lower = target_ratio * (1 - deviation_ratio); + MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio) + << " target_upper:" << std::to_string(target_upper) + << " target_lower:" << std::to_string(target_lower); + EXPECT_LE(current_ratio, target_upper); + EXPECT_GE(current_ratio, target_lower); + } + pre_num = neighbor.second; + pre_weight = weights[index]; + ++index; + } + } }; TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { @@ -131,44 +186,75 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); std::shared_ptr neighbors; - s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); - EXPECT_TRUE(s.IsOk()); - EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + { + MS_LOG(INFO) << "Test random sampling."; + NodeNeighborsMap number_neighbors; + int count = 0; + while (count < 1000) { + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + ParsingNeighbors(neighbors, number_neighbors); + ++count; + } + CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1}); + } + + { + MS_LOG(INFO) << "Test edge weight sampling."; + NodeNeighborsMap number_neighbors; + int count = 0; + while (count < 1000) { + neighbors.reset(); + s = + graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + ParsingNeighbors(neighbors, number_neighbors); + 
++count; + } + CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8}); + } neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); neighbors.reset(); s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, - {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); + {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); neighbors.reset(); - s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); 
neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); } diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata index ad97dbae98de88cd441cf63134099a059eb1c6aa..a52fb92282eb5e6201b7b0d23df9fd1b33172d10 100644 GIT binary patch literal 70026 zcmeI5O>f*p7{{Gtlk^QJv_K05f+>Z=f|59DX(VV`5T{B=fP_$mmQB{)b-OQ-?NCY- zi4!NjKypJuReRSFJ5Oob?V{rx0LMe%=rI*p5Hw4#K|Rl z)7J3Xv|Dby;RdzP|H-wNoy801KdRb(x8eR+^L)462<_!-oAQ9YY}cbG@K-k?*RI;( zTFdR!YktFCp5LmbS9*RBMszZ7Vd$>)x(%I`QOghX7T#LV5AdezH(OCUyzi>^dc8?_ zkknlgR?xdS%<%ktJQ1hM-pAxFEX*g95j9@6uU!4&O6{}DUtcEbRINtGlRK$mGw5xs zlL3a960F)i&kNm%!X!b2t$Lg*oe86Q5b;oku;thice)b`N@z<<7n4aA+LxEVDKj#L zPpzr7US~aU!_e<_*#ymLSUp3|tL@&ZC^_405MrKNk2V6gwm=Xn1B>Z_Va1ZcZm*%< zrVP;AlmRC_u#_H%g-J#YH(VpT+93l0DOR#C&7i*Cs<9MAE}PQ%Sl+Z5@{k;vZqwdM zC9+_XC1jgv#a#Pff%vfXem8O!N0u+{x7=&@>RC6$;pchqV=Ce${^ysdzrrg&TxGdV zxXKOAe)(}6%JQYBTdoH~w4w|ey!&L$4bMH_aUL@3G2#xsSzF?e*Bhj)!qBD7z>eH7 zYSg2;P3j@t10p|+{57(yyU}h}?H}BL9w}0p&WH2%R@~L`9RxrC1V8`;KmY_l00ck) z1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`; zKmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l z00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck)1V8`;KmY_l00ck) z1V8`;KmY_l;Ashj>J2mG;{+e7@OPFK&pc)$N6G6n`4Vfb%Anu>m`%rvlSx(d5Lq*2 zO-xQ5%$kZP2E_X6CyMo?@`-)0o{=>%PModE&i1xyJyouCl8ZG2u39Hil>>Z&^R{SV 
zTvGB(N~IKQ>XhV0fS*>a%LShJk`OA6o9HA}uGR;H4W;&w7QRe4snsU}A4=oQ;m zO0pekU^Fdy#kSKW*%HWV^HL$tlw?ctN|fVWKRnDe$#yoGe8zh_(zK2!wuefx9cftV z-X1Q=cBEOU*&Zp$cBE0M2k2-?ww9G@Qfeh18?o)My=O0R9VTP)V-$^lk~}J9FHG(r zuc}1Wr^lJ5Ntc%&ph`DQ;+7ws%$Z7kr*u<}m!zCjLFuLvUjCqD&Q$6+=1ir2W6o6S zHs(yFUSrO5(Ers~_3|S_9!rdw@+0GUx6PY6&L751IbO^9)0ip8OM0())1B?V&6|4O zKjuvx=RfnNWOG;82Ak_ho;PF6lpmQv|CljTMwI@^;e>fp^4p3rQ;yfN=8Tzgyq7{HE&8D%*>naY+p2QN*a=RQ_uU{m?=LpAH^6DIKL7I!b4Bl-fE^L9VJ^wsjH)OHdgvHezCG^ynkF~uOHrZS}-q@>0agOI8S zDer&@PG?HpPenE6IE$!egj7>RwVRR}bDV|LY^GEp?1wU?3SmE-DfKXw;FvQg50LX zH8_~mtCG}%gGrr|q(p+An&667Lq07@-RDsil6p;&+T~KsR8(5M{dGx7U@s(fMv@ZP z3rW2pNj*GRf^SMvdk2$xOOonyDL;i>>+xewk`l#mJtd{}`0=(RC5oYtlr2e#YQK?+ zO6y1JJCc-$s*u!KNlFw$A*pvIDNzjVRDxPR=+8+~k9aY3Qc`0cwncA$Pm&VYyQ-Al OkBRpsDG}8^QcnP?XBRI3 delta 1217 zcmeBL%yMcrGh^>WMln`HGh+*ri88k*%Q2o~Ha0ch{GPFzd2<`{5YRf^ zkdaYzsvM&Xx=;qD&?ZcwKj=b+Xy#0xz=*CybownUVk%6S>hiFN?Er}xqd0#$6EmYM zx|lB(u^Cvz9%B*HW5G195{uXoEMokun0lkIh%H4GD`aC7o&Eufm^B-Uc|dhWNL~eN zyNIDgbh->Xx;D}2=~%=zVG;X_Ma+W()4a)8#O`4c)8xd|TY^PwKNc}AE=;|lSi}}! z5qpg)rpS#MbdB7Y#-G6=CeDMYE(wd+S}bBeu!uSHVw%^7MeHUPF%>>cz4=(gc3}}? 
k<;T<;fJJOJs@O$-%%C$Az%;4`i`a22VnTwLdSiiN0PpI+DF6Tf diff --git a/tests/ut/data/mindrecord/testGraphData/testdata.db b/tests/ut/data/mindrecord/testGraphData/testdata.db index 0f022589f4c0c123c8a38fdb69b8abbcbaaa9668..bade6acbc6539cc67314087fa72c53baab912391 100644 GIT binary patch literal 16384 zcmeI1du$ZP9mi)Mx6hlIJKvq}?DON!yEe8l_>D1G0)sEc0c^0%Bck9sFg}dU^6(+r zAXJrB5ZXUzRH~vh6-p`vC2i>+D6Le88Z|-17#r98(TG-|NmCJ2)FurjiK@`foecx& zDE+sMW^}Xrxt))8fBNp+?d@LQ`ZdEDk{H^vfA?S}(Fg)SlECsr0sx@Sb-KzmnJ%;2 z?lR{(edkJFWCA_0+rNI)bY5)cW91VjQN0g-@6;Qus%)ruBP zCZ+9}!6!#a->UZBj)Cs>wJApG%HC9ae=5=6zG6)(!M0aswM2hvbAO_{hxQF?)>M77 zMbe}3xb$=?x4W+Hj@0H<-`q-mWJhxwD-(&{o=pQ?9f?O)Jb0R2t}Q z-^4tg)AQJ4eW`xteZDi0>ZS+&jFYXWv$to%x`(U!jB~%96>EA{JSd3g%+>S2`!7ze zphumq?v(>|*&1sSgPF|!;U^Dd2;Gnkz9u_XKRY%wynie+FuY^dO?9SvbE6HjqhnpVCtb4=SRBzWxdZK!}*0%RP zkyxF2qH^}c=E|s>Ep@h}+_EvZY|1U0tEv{rTC_bbfnje4`O4V7QF>QDFnA!dC-;6} zZn6y2&w)ZeY)1A22WfB9Hv{}qT^@>T80YYKj1X{F+2(f zVH=G3|L(uyKjZ(pe~Z7_kL_Fb>-G=rgZ6s6#@4JqTEDcuXYI9CTNRdM{=xjY`5kkY z*qniGB8f9X0FwQ8}Q*+ICh*Q!Ne}jZQ@#u=q7H6o{oh{0$iq6BiCv` z&!L;q>gZ`w3a$+M+y<^ykJh8-@?VKolM*nUR@{27RfjaRK7V=sE2J1)rB)r+s)gr~ z7MaLjPKv-3wQ9N6Lbw;6k8F%gkT{s6)=t$@z5(QJlspd*m_9y<15DXn5 z`CyVbRa~jkK5KuHcPa!)1dJ1>@=+-}>FJc8XK%O9<~ez%NZ122&OEMEVaM$4L8_6E z2gaQWu2gPK+Oc43@GQyO4QAYOo+P=}5oOfycmzx+P8rWKljig0ZA|cWf=0l&o8(DS!rW*+kKV_Gz+H4#Zh~u- znzp$S9YXID|1NM@c1yWdiSe3g(}{YB*u&t8>XvY=V&j1E8hi$6#7bv_DCSy4MzwJO z7QttTna&1L#I@r34WrtB1{M(`4W=|V&b12llll#RgZ~WC(_qr)7ILiuy-Po7zv^!w zS~{CV0oRIY_w+7%mHjI54bdc=7|$}J+Bxl>b2%*z7{IHXq><0Cuknfmp#`x6C z(CIE)PKYbz`PzKP=Wk_J`Y8i@;Hs=U z0j`AAE$WMUQa|N^DOqzcS3+t)-J-p$Cp|DJ`y9lTpmI?SXpP#-9=IZ_4&+LH7PFud~m} z@2Y+3l;>GVldBl=I{8`UO|{S4Z6*z_qRX1RPT8uw>5UpmovUcl1uDIxZ1qO;q{dZz z(tawX<#)VYTGGc=RH>3mx8<}q>`SU#MFDS7X`gi4+o>iMt|Ehz;4N@h+UKQ}5;+#ecy!@HPA!d>;P{ z{|J8{f15rH7{h721^3{UxCPJ0r8tCj^iT9R^e6N?G=YAF&Y?5t2j~R)20DmF&;aT~ zooFelMGlH08~zJ^2yepQ!>jOh_zU_3;WYd%JO&TJJ#ahR1Xsgl^uVHt1VjQN0g-@6 zKqMd%5DAC`L;@m#|J?*^bM_9)ke0SFUb2*N>k`I`TN$@3X1u6{aq}X^P0frOn;17V 
zGOll6TvyMywvO?_TE+_&GOk&`czzAz>iLYTsu@>SF`ie+xMCjT@(RXoIb+9VT;?!N zmN8Bw8J8v)my|LtE@50$%s5`exG>JRppbE_fN?a&I6ulblFv9CVH^rG&I>UP<}nTg z8RGzBgc(D`*bf=oe#Vx~*t8fM^pW8G`KB9;HJ!0fW32iZD=K4Iu}m}fWPm;;xKD~7i!)>_vtuwb_;1{k BLYV*n literal 16384 zcmeI1du$ZP9mi*P@AmdJGiTq~_p;Bse%ttM^Dxi~+gyx=V^iChL<(3=JU-4gOTiZ< zElH428YNL9AW|cg{*g)qkVu4@N@+p_f3!ks{lqn@8X=VcQdOc72?YX?s-VuD&0x|| z`d|M*GxF|!yuXj;@x9$U-@5-N-KjLONB8dEGn}@YfgeZ`*ks!P0MGJ957{QuedgSQ zCOv>3IC?-`#~&g77ASB8z>n}0J%}a}5DAC`L;@lKk$^}*Bp?zH35Wz9+XS{MYRGY< z=hDN^PLTdBZGD@EdfJ{$JS0y?U!rXwVGp#mcPH$J=PNUsJ&<^6!0zd#+xG76s#o(Q zUnmxnUYD}nb@gmcJeBBQKFE)Bw0yGCw)=W_40Ua`A8lPd12*fo#@^P}nV`~8U)zp{ z9?$80;)(vmz(encouNbz9r$}rWFVhiYRJs7#vV?m_otqHK27L}Oz<_CTH{h}G_`*+J(L<*auc12zHGH=sX9rX+v|?} zce#0~oZg+<&#AU7SAX`?%!qUf*>3G^eO*ts^*wEGO*~z>^x|cQYFL`i(y-Zmb9UcS zRaGmip|+R=Qtk+OadO`T{Z>CT{Cs+E_V=OXOJS&S85DZ6X_*@i(ru$A0sK!qgFnQ- z#1pt1*I*0XM6>8LdL50T&8Qdwcmnm&8I%U0T4O>69V&=EzU(Mf{ zhs_DI%Pcn)F)<%;1J0NbHu6QD3yAhp2|HIfFuu05vTG|C3E#^720jRF1IB27zw&y z(W&4l<=Us(?VK~YB_!a2DW{yHlxZ((pXRjYoFTb;XfI_PrBrLzUiM$eX(c%mV8JQn zC?%RtYxnp1FA)DYxagE{lw#kk=EIl$y#$Yg8K;<|IKHF4SxoR{g2w4|9gbrA27O1- zElda;qrKQ1rN|fa4Wa|+7O}>_v{S@U;_6jj45{bxBdG3yj8AX*a4J28$D_1;nIttRUf@g>2W6Xhro?^*91 z^SsqW)FfSWgrkJLuXxXzTg>ysJ4zQF<|rX=hxZlZlDUO=Mrpbsj*{=yydB0a;}TH_ zoo+ry$@9#4HT_Ft7m*2=v-4IediHHq3wqx6%;{ zSQ%GN`W)Xe7tG5x<`_t6P{!5MKF0-fvW++fl)q6L)MoXx3ua{-aturUO#a4uPHlF< zjBHyR!;}xopLrABb1s;cZIfdd@@DyZ^N%^`v7WZ=M z9&m#iH|3<8ihDS91)QPAKIx{ryHKHZ>ETCpc71Z13U{S_ZZawdbX}s$rb_0MJX>C$ zJ^=2zqai7vWiCTnTv0pV5`+H5u@UYC_#XZWU&nvLv-l7A_xJ<+8~iK$CO(MMIE8oN ze!K<$2shyhT!3@1ivEpmp?{#u=pwp+&Z0BuI68*jLWj@`XaWtRL9`WZLaR|VDndb| z!~5_yyaE3X7vZ1iQ-qJ;N%(7c1pWg4oDM9SNI)bY5)cW91VjQN0g-@6KqMd%csvs@ zmcN%t8#Xdrzky-vdWP#-8LnN&aLrnVtJg4WShy+9eB7w&!0Yj1t(_}I|GMpa%55UMTr2qf` diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 053f232cfb1..83f84dc7b2b 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ 
-17,6 +17,7 @@ import pytest import numpy as np import mindspore.dataset as ds from mindspore import log as logger +from mindspore.dataset.engine import SamplingStrategy DATASET_FILE = "../data/mindrecord/testGraphData/testdata" SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" @@ -97,7 +98,10 @@ def test_graphdata_getsampledneighbors(): nodes = g.get_nodes_from_edges(edges) assert len(nodes) == 40 neighbor = g.get_sampled_neighbors( - np.unique(nodes[0:21, 0]), [2, 3], [2, 1]) + np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM) + assert neighbor.shape == (10, 9) + neighbor = g.get_sampled_neighbors( + np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT) assert neighbor.shape == (10, 9) diff --git a/tests/ut/python/dataset/test_graphdata_distributed.py b/tests/ut/python/dataset/test_graphdata_distributed.py index 06bc8d460f0..97b4b8e137a 100644 --- a/tests/ut/python/dataset/test_graphdata_distributed.py +++ b/tests/ut/python/dataset/test_graphdata_distributed.py @@ -20,6 +20,7 @@ from multiprocessing import Process import numpy as np import mindspore.dataset as ds from mindspore import log as logger +from mindspore.dataset.engine import SamplingStrategy DATASET_FILE = "../data/mindrecord/testGraphData/testdata" @@ -68,9 +69,9 @@ class GNNGraphDataset(): neg_nodes = self.g.get_neg_sampled_neighbors( node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ - 2, 2], neighbor_types=[2, 1]) - neg_nodes_neighbors = self.g.get_sampled_neighbors( - node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) + 2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM) + neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], + neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT) nodes_neighbors_features = self.g.get_node_feature( 
node_list=nodes_neighbors, feature_types=[2, 3]) neg_neighbors_features = self.g.get_node_feature(