diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc index 936ba2804e9..5e8c5e3a1c2 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc @@ -65,9 +65,9 @@ PYBIND_REGISTER( }) .def("get_sampled_neighbors", [](gnn::GraphData &g, std::vector node_list, std::vector neighbor_nums, - std::vector neighbor_types) { + std::vector neighbor_types, SamplingStrategy strategy) { std::shared_ptr out; - THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); + THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out)); return out; }) .def("get_neg_sampled_neighbors", @@ -114,8 +114,15 @@ PYBIND_REGISTER( return out; })) .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); }) - .def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); }); + .def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); }); })); +PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) { + (void)py::enum_(*m, "SamplingStrategy", py::arithmetic()) + .value("DE_SAMPLING_RANDOM", SamplingStrategy::kRandom) + .value("DE_SAMPLING_EDGE_WEIGHT", SamplingStrategy::kEdgeWeight) + .export_values(); + })); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/core/constants.h b/mindspore/ccsrc/minddata/dataset/core/constants.h index 480eb682f0e..d749827c880 100644 --- a/mindspore/ccsrc/minddata/dataset/core/constants.h +++ b/mindspore/ccsrc/minddata/dataset/core/constants.h @@ -71,6 +71,9 @@ enum class NormalizeForm { kNfkd, }; +// Possible values for SamplingStrategy +enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 }; + // convenience functions for 32bit int bitmask inline bool 
BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h index b11c20bf3a5..b80e73de3f4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/edge.h @@ -35,10 +35,11 @@ class Edge { // Constructor // @param EdgeIdType id - edge id // @param EdgeType type - edge type + // @param WeightType weight - edge weight // @param std::shared_ptr src_node - source node // @param std::shared_ptr dst_node - destination node - Edge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} + Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr src_node, std::shared_ptr dst_node) + : id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {} virtual ~Edge() = default; @@ -48,6 +49,9 @@ class Edge { // @return NodeIdType - Returned edge type EdgeType type() const { return type_; } + // @return WeightType - Returned edge weight + WeightType weight() const { return weight_; } + // Get the feature of a edge // @param FeatureType feature_type - type of feature // @param std::shared_ptr *out_feature - Returned feature @@ -77,6 +81,7 @@ class Edge { protected: EdgeIdType id_; EdgeType type_; + WeightType weight_; std::shared_ptr src_node_; std::shared_ptr dst_node_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto index 1342d047cc5..f95a823eb7d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto @@ -71,6 +71,7 @@ message GnnGraphDataRequestPb { repeated int32 number = 4; // samples number TensorPb id_tensor = 5; // input ids ,node id or edge id GnnRandomWalkPb 
random_walk = 6; + int32 strategy = 7; } message GnnGraphDataResponsePb { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h index a790c363f38..c50bf194dd5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h @@ -76,11 +76,13 @@ class GraphData { // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param SamplingStrategy strategy - Sampling strategy // @param std::shared_ptr *out - Returned neighbor's id. // @return Status The status code returned virtual Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) = 0; + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) = 0; // Get negative sampled neighbors. 
// @param std::vector node_list - List of nodes @@ -95,7 +97,7 @@ class GraphData { // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param float step_away_param - in out hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc index 6fdde154268..a9f618ccbe9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc @@ -137,7 +137,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector &node_list Status GraphDataClient::GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) { #if !defined(_WIN32) && !defined(_WIN64) GnnGraphDataRequestPb request; GnnGraphDataResponsePb response; @@ -151,6 +152,7 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector &node_ for (const auto &type : neighbor_types) { request.add_type(static_cast(type)); } + request.set_strategy(static_cast(strategy)); RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); #endif return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h index 36cb9169a42..0e8d08f11bf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h @@ 
-86,10 +86,12 @@ class GraphDataClient : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param SamplingStrategy strategy - Sampling strategy // @param std::shared_ptr *out - Returned neighbor's id. // @return Status The status code returned Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) override; + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) override; // Get negative sampled neighbors. // @param std::vector node_list - List of nodes @@ -104,7 +106,7 @@ class GraphDataClient : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param float step_away_param - in out hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status The status code returned diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc index b97cad0ffaa..70a090bb97b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc @@ -171,7 +171,8 @@ Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) { Status GraphDataImpl::GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) { 
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), "The sizes of neighbor_nums and neighbor_types are inconsistent."); @@ -199,7 +200,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector &node_li std::shared_ptr node; RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node)); std::vector out; - RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out)); + RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out)); neighbors.insert(neighbors.end(), out.begin(), out.end()); } } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h index 6b6fc05d475..b5db50768b5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h @@ -80,10 +80,12 @@ class GraphDataImpl : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector neighbor_nums - Number of neighbors sampled per hop // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param SamplingStrategy strategy - Sampling strategy // @param std::shared_ptr *out - Returned neighbor's id. // @return Status The status code returned Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) override; + const std::vector &neighbor_types, SamplingStrategy strategy, + std::shared_ptr *out) override; // Get negative sampled neighbors. 
// @param std::vector node_list - List of nodes @@ -98,7 +100,7 @@ class GraphDataImpl : public GraphData { // @param std::vector node_list - List of nodes // @param std::vector meta_path - node type of each step // @param float step_home_param - return hyper parameter in node2vec algorithm - // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param float step_away_param - in out hyper parameter in node2vec algorithm // @param NodeIdType default_node - default node id // @param std::shared_ptr *out - Returned nodes id in walk path // @return Status The status code returned @@ -194,7 +196,7 @@ class GraphDataImpl : public GraphData { std::vector node_list_; std::vector meta_path_; float step_home_param_; // Return hyper parameter. Default is 1.0 - float step_away_param_; // Inout hyper parameter. Default is 1.0 + float step_away_param_; // In out hyper parameter. Default is 1.0 NodeIdType default_node_; int32_t num_walks_; // Number of walks per source. Default is 1 diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h index 2bf1ad2b51e..49c57a1c3e1 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h @@ -50,7 +50,7 @@ class GraphDataServer { enum ServerState state() { return state_; } - bool IsStoped() { + bool IsStopped() { if (state_ == kGdsStopped) { return true; } else { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc index 3fe4c96ce81..04b930bb557 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc @@ -78,7 +78,7 @@ grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context, } break; case 
GraphDataServer::kGdsStopped: - response->set_error_msg("Stoped"); + response->set_error_msg("Stopped"); break; } } else { @@ -222,8 +222,9 @@ Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *re neighbor_types.resize(request->type().size()); std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), [](const google::protobuf::int32 type) { return static_cast(type); }); + SamplingStrategy strategy = static_cast(request->strategy()); std::shared_ptr tensor; - RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); + RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &tensor)); TensorPb *result = response->add_result_data(); RETURN_IF_NOT_OK(TensorToPb(tensor, result)); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 1d043ced745..16dcfa4d3a0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -39,7 +39,9 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int row_id_(0), shard_reader_(nullptr), graph_feature_parser_(nullptr), - keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} + required_key_( + {"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}), + optional_key_({{"weight", false}}) {} Status GraphLoader::GetNodesAndEdges() { NodeIdMap *n_id_map = &graph_impl_->node_id_map_; @@ -62,7 +64,7 @@ Status GraphLoader::GetNodesAndEdges() { CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); 
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); - RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); + RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight())); e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); dq.pop_front(); @@ -95,12 +97,18 @@ Status GraphLoader::InitAndLoad() { graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); mindrecord::json schema = graph_impl_->data_schema_["schema"]; - for (const std::string &key : keys_) { + for (const std::string &key : required_key_) { if (schema.find(key) == schema.end()) { RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); } } + for (auto op_key : optional_key_) { + if (schema.find(op_key.first) != schema.end()) { + optional_key_[op_key.first] = true; + } + } + if (graph_impl_->server_mode_) { #if !defined(_WIN32) && !defined(_WIN64) int64_t total_blob_size = 0; @@ -128,7 +136,11 @@ Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrec DefaultNodeFeatureMap *default_feature) { NodeIdType node_id = col_jsn["first_id"]; NodeType node_type = static_cast(col_jsn["type"]); - (*node) = std::make_shared(node_id, node_type); + WeightType weight = 1; + if (optional_key_["weight"]) { + weight = col_jsn["weight"]; + } + (*node) = std::make_shared(node_id, node_type, weight); std::vector indices; RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); if (graph_impl_->server_mode_) { @@ -174,9 +186,13 @@ Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrec EdgeIdType edge_id = col_jsn["first_id"]; EdgeType edge_type = static_cast(col_jsn["type"]); NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; - std::shared_ptr src = std::make_shared(src_id, -1); - std::shared_ptr dst = std::make_shared(dst_id, 
-1); - (*edge) = std::make_shared(edge_id, edge_type, src, dst); + WeightType edge_weight = 1; + if (optional_key_["weight"]) { + edge_weight = col_jsn["weight"]; + } + std::shared_ptr src = std::make_shared(src_id, -1, 1); + std::shared_ptr dst = std::make_shared(dst_id, -1, 1); + (*edge) = std::make_shared(edge_id, edge_type, edge_weight, src, dst); std::vector indices; RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); if (graph_impl_->server_mode_) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h index 7b397547fb5..f3a7fd0edff 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h @@ -110,7 +110,8 @@ class GraphLoader { std::vector e_feature_maps_; std::vector default_node_feature_maps_; std::vector default_edge_feature_maps_; - const std::vector keys_; + const std::vector required_key_; + std::unordered_map optional_key_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc index d20be6e318e..27eb508ff6d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc @@ -21,8 +21,9 @@ namespace mindspore { namespace dataset { namespace gnn { -LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) - : Edge(id, type, src_node, dst_node) {} +LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr src_node, + std::shared_ptr dst_node) + : Edge(id, type, weight, src_node, dst_node) {} Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = features_.find(feature_type); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h 
b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h index 9c365723b71..fa860fb9297 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h @@ -34,9 +34,11 @@ class LocalEdge : public Edge { // Constructor // @param EdgeIdType id - edge id // @param EdgeType type - edge type + // @param WeightType weight - edge weight // @param std::shared_ptr src_node - source node // @param std::shared_ptr dst_node - destination node - LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); + LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr src_node, + std::shared_ptr dst_node); ~LocalEdge() = default; diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc index 8eaf9bb7163..bd7114f571a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.cc @@ -16,6 +16,7 @@ #include "minddata/dataset/engine/gnn/local_node.h" #include +#include #include #include @@ -26,7 +27,10 @@ namespace mindspore { namespace dataset { namespace gnn { -LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); } +LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight) + : Node(id, type, weight), rnd_(GetRandomDevice()) { + rnd_.seed(GetSeed()); +} Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = features_.find(feature_type); @@ -44,13 +48,13 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vectorsecond.size()); - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(), + neighbors.resize(itr->second.first.size()); + std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(), [](const std::shared_ptr node) { return node->id(); }); } 
else { - neighbors.resize(itr->second.size() + 1); + neighbors.resize(itr->second.first.size() + 1); neighbors[0] = id_; - std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1, [](const std::shared_ptr node) { return node->id(); }); } } else { @@ -63,8 +67,8 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector> &neighbors, int32_t samples_num, - std::vector *out) { +Status LocalNode::GetRandomSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out) { std::vector shuffled_id(neighbors.size()); std::iota(shuffled_id.begin(), shuffled_id.end(), 0); std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); @@ -75,14 +79,33 @@ Status LocalNode::GetSampledNeighbors(const std::vector> & return Status::OK(); } -Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, +Status LocalNode::GetWeightSampledNeighbors(const std::vector> &neighbors, + const std::vector &weights, int32_t samples_num, + std::vector *out) { + CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(), + "The number of neighbors does not match the weight."); + std::discrete_distribution discrete_dist(weights.begin(), weights.end()); + for (int32_t i = 0; i < samples_num; ++i) { + NodeIdType index = discrete_dist(rnd_); + out->emplace_back(neighbors[index]->id()); + } + return Status::OK(); +} + +Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, std::vector *out_neighbors) { std::vector neighbors; neighbors.reserve(samples_num); auto itr = neighbor_nodes_.find(neighbor_type); if (itr != neighbor_nodes_.end()) { - while (neighbors.size() < samples_num) { - RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors)); + if (strategy == SamplingStrategy::kRandom) { + while (neighbors.size() < samples_num) { + 
RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors)); + } + } else if (strategy == SamplingStrategy::kEdgeWeight) { + RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors)); + } else { + RETURN_STATUS_UNEXPECTED("Invalid strategy"); } } else { MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; @@ -95,12 +118,15 @@ Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_nu return Status::OK(); } -Status LocalNode::AddNeighbor(const std::shared_ptr &node) { +Status LocalNode::AddNeighbor(const std::shared_ptr &node, const WeightType &weight) { auto itr = neighbor_nodes_.find(node->type()); if (itr != neighbor_nodes_.end()) { - itr->second.push_back(node); + itr->second.first.push_back(node); + itr->second.second.push_back(weight); } else { - neighbor_nodes_[node->type()] = {node}; + std::vector> nodes = {node}; + std::vector weights = {weight}; + neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights)); } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h index 75ba021ee28..7ec030556a8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -18,6 +18,7 @@ #include #include +#include #include #include "minddata/dataset/engine/gnn/node.h" @@ -33,7 +34,7 @@ class LocalNode : public Node { // Constructor // @param NodeIdType id - node id // @param NodeType type - node type - LocalNode(NodeIdType id, NodeType type); + LocalNode(NodeIdType id, NodeType type, WeightType weight); ~LocalNode() = default; @@ -53,15 +54,16 @@ class LocalNode : public Node { // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor // @param int32_t samples_num - Number of neighbors to 
be acquired + // @param SamplingStrategy strategy - Sampling strategy // @param std::vector *out_neighbors - Returned neighbors id // @return Status The status code returned - Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, std::vector *out_neighbors) override; // Add neighbor of node // @param std::shared_ptr node - // @return Status The status code returned - Status AddNeighbor(const std::shared_ptr &node) override; + Status AddNeighbor(const std::shared_ptr &node, const WeightType &) override; // Update feature of node // @param std::shared_ptr feature - @@ -69,12 +71,16 @@ class LocalNode : public Node { Status UpdateFeature(const std::shared_ptr &feature) override; private: - Status GetSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, - std::vector *out); + Status GetRandomSampledNeighbors(const std::vector> &neighbors, int32_t samples_num, + std::vector *out); + + Status GetWeightSampledNeighbors(const std::vector> &neighbors, + const std::vector &weights, int32_t samples_num, + std::vector *out); std::mt19937 rnd_; std::unordered_map> features_; - std::unordered_map>> neighbor_nodes_; + std::unordered_map>, std::vector>> neighbor_nodes_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h index 1f5aaf55a8b..3382df5c244 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -28,6 +28,7 @@ namespace dataset { namespace gnn { using NodeType = int8_t; using NodeIdType = int32_t; +using WeightType = float; constexpr NodeIdType kDefaultNodeId = -1; @@ -36,7 +37,8 @@ class Node { // Constructor // @param NodeIdType id - node id // @param NodeType type - node type - Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} + // @param WeightType 
weight - node weight + Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {} virtual ~Node() = default; @@ -46,6 +48,9 @@ class Node { // @return NodeIdType - Returned node type NodeType type() const { return type_; } + // @return WeightType - Returned node weight + WeightType weight() const { return weight_; } + // Get the feature of a node // @param FeatureType feature_type - type of feature // @param std::shared_ptr *out_feature - Returned feature @@ -62,15 +67,16 @@ class Node { // Get the sampled neighbors of a node // @param NodeType neighbor_type - type of neighbor // @param int32_t samples_num - Number of neighbors to be acquired + // @param SamplingStrategy strategy - Sampling strategy // @param std::vector *out_neighbors - Returned neighbors id // @return Status The status code returned - virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, + virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy, std::vector *out_neighbors) = 0; // Add neighbor of node // @param std::shared_ptr node - // @return Status The status code returned - virtual Status AddNeighbor(const std::shared_ptr &node) = 0; + virtual Status AddNeighbor(const std::shared_ptr &node, const WeightType &weight) = 0; // Update feature of node // @param std::shared_ptr feature - @@ -80,6 +86,7 @@ class Node { protected: NodeIdType id_; NodeType type_; + WeightType weight_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/include/constants.h b/mindspore/ccsrc/minddata/dataset/include/constants.h index e5474847f4a..2f34ea573dc 100644 --- a/mindspore/ccsrc/minddata/dataset/include/constants.h +++ b/mindspore/ccsrc/minddata/dataset/include/constants.h @@ -71,6 +71,9 @@ enum class NormalizeForm { kNfkd, }; +// Possible values for SamplingStrategy +enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 }; + // convenience functions for 
32bit int bitmask inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 7609d2266ec..b2ab05150eb 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip. from ..core import config from .cache_client import DatasetCache from .datasets import * -from .graphdata import GraphData +from .graphdata import GraphData, SamplingStrategy from .iterators import * from .samplers import * from .serializer_deserializer import compare, deserialize, serialize, show diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 58fe40f47d9..68afda3f3be 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -18,10 +18,12 @@ and provides operations related to graph data. """ import atexit import time +from enum import IntEnum import numpy as np from mindspore._c_dataengine import GraphDataClient from mindspore._c_dataengine import GraphDataServer from mindspore._c_dataengine import Tensor +from mindspore._c_dataengine import SamplingStrategy as Sampling from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \ @@ -29,6 +31,17 @@ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_ check_gnn_random_walk +class SamplingStrategy(IntEnum): + RANDOM = 0 + EDGE_WEIGHT = 1 + + +DE_C_INTER_SAMPLING_STRATEGY = { + SamplingStrategy.RANDOM: Sampling.DE_SAMPLING_RANDOM, + SamplingStrategy.EDGE_WEIGHT: Sampling.DE_SAMPLING_EDGE_WEIGHT, +} + + class GraphData: """ Reads the graph dataset used for GNN training from the shared file and database. 
@@ -86,7 +99,7 @@ class GraphData: dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) atexit.register(stop) try: - while self._graph_data.is_stoped() is not True: + while self._graph_data.is_stopped() is not True: time.sleep(1) except KeyboardInterrupt: raise Exception("Graph data server receives KeyboardInterrupt.") @@ -185,7 +198,7 @@ class GraphData: return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() @check_gnn_get_sampled_neighbors - def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): + def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM): """ Get sampled neighbor information. @@ -199,6 +212,11 @@ class GraphData: node_list (Union[list, numpy.ndarray]): The given list of nodes. neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop. neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop. + strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM). + It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]. + + - SamplingStrategy.RANDOM, random sampling with replacement. + - SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability. Returns: numpy.ndarray, array of neighbors. @@ -215,10 +233,12 @@ class GraphData: TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray. 
""" + if not isinstance(strategy, SamplingStrategy): + raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.") if self._working_mode == 'server': raise Exception("This method is not supported when working mode is server.") return self._graph_data.get_sampled_neighbors( - node_list, neighbor_nums, neighbor_types).as_array() + node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array() @check_gnn_get_neg_sampled_neighbors def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type): @@ -342,7 +362,7 @@ class GraphData: target_nodes (list[int]): Start node list in random walk meta_path (list[int]): node type for each walk step step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0). - step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0). + step_away_param (float, optional): in out hyper parameter in node2vec algorithm (Default = 1.0). default_node (int, optional): default node if no more neighbors found (Default = -1). A default value of -1 indicates that no node is given. 
diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 724207d5aea..39af3cbb384 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1114,7 +1114,7 @@ def check_gnn_get_sampled_neighbors(method): @wraps(method) def new_method(self, *args, **kwargs): - [node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs) + [node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs) check_gnn_list_or_ndarray(node_list, 'node_list') diff --git a/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py b/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py index 1da1ced2f7d..1ee1fa343ff 100644 --- a/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py +++ b/model_zoo/utils/graph_to_mindrecord/graph_map_schema.py @@ -37,6 +37,7 @@ class GraphMapSchema: "second_id": {"type": "int64"}, "third_id": {"type": "int64"}, "type": {"type": "int32"}, + "weight": {"type": "float32"}, "attribute": {"type": "string"}, # 'n' for ndoe, 'e' for edge "node_feature_index": {"type": "int32", "shape": [-1]}, "edge_feature_index": {"type": "int32", "shape": [-1]} @@ -91,8 +92,11 @@ class GraphMapSchema: logger.info("node cannot be None.") raise ValueError("node cannot be None.") - node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], - "node_feature_index": []} + node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n', + "type": node["type"], "node_feature_index": []} + if "weight" in node: + node_graph["weight"] = node["weight"] + for i in range(self.num_node_features): k = i + 1 node_field_key = 'feature_' + str(k) @@ -129,8 +133,11 @@ class GraphMapSchema: logger.info("edge cannot be None.") raise ValueError("edge cannot be None.") - edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', 
- "type": edge["type"], "edge_feature_index": []} + edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0, + "attribute": 'e', "type": edge["type"], "edge_feature_index": []} + + if "weight" in edge: + edge_graph["weight"] = edge["weight"] for i in range(self.num_edge_features): k = i + 1 diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 66ad0a6e857..521ffff6342 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -15,6 +15,7 @@ */ #include #include +#include #include #include @@ -38,6 +39,60 @@ using namespace mindspore::dataset::gnn; class MindDataTestGNNGraph : public UT::Common { protected: MindDataTestGNNGraph() = default; + + using NumNeighborsMap = std::map; + using NodeNeighborsMap = std::map; + void ParsingNeighbors(const std::shared_ptr &neighbors, NodeNeighborsMap &node_neighbors) { + auto shape_vec = neighbors->shape().AsVector(); + uint32_t num_members = 1; + for (size_t i = 1; i < shape_vec.size(); ++i) { + num_members *= shape_vec[i]; + } + uint32_t index = 0; + NodeIdType src_node = 0; + for (auto node_itr = neighbors->begin(); node_itr != neighbors->end(); + ++node_itr, ++index) { + if (index % num_members == 0) { + src_node = *node_itr; + continue; + } + auto src_node_itr = node_neighbors.find(src_node); + if (src_node_itr == node_neighbors.end()) { + node_neighbors[src_node] = {{*node_itr, 1}}; + } else { + auto nei_itr = src_node_itr->second.find(*node_itr); + if (nei_itr == src_node_itr->second.end()) { + src_node_itr->second[*node_itr] = 1; + } else { + src_node_itr->second[*node_itr] += 1; + } + } + } + } + + void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector &weights, + float deviation_ratio = 0.1) { + EXPECT_EQ(number_neighbors.size(), weights.size()); + int index = 0; + uint32_t pre_num = 0; + WeightType pre_weight = 1; + for (auto neighbor : 
number_neighbors) { + if (pre_num != 0) { + float target_ratio = static_cast(pre_weight) / static_cast(weights[index]); + float current_ratio = static_cast(pre_num) / static_cast(neighbor.second); + float target_upper = target_ratio * (1 + deviation_ratio); + float target_lower = target_ratio * (1 - deviation_ratio); + MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio) + << " target_upper:" << std::to_string(target_upper) + << " target_lower:" << std::to_string(target_lower); + EXPECT_LE(current_ratio, target_upper); + EXPECT_GE(current_ratio, target_lower); + } + pre_num = neighbor.second; + pre_weight = weights[index]; + ++index; + } + } }; TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { @@ -131,44 +186,75 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; }); std::shared_ptr neighbors; - s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors); - EXPECT_TRUE(s.IsOk()); - EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + { + MS_LOG(INFO) << "Test random sampling."; + NodeNeighborsMap number_neighbors; + int count = 0; + while (count < 1000) { + neighbors.reset(); + s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + ParsingNeighbors(neighbors, number_neighbors); + ++count; + } + CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1}); + } + + { + MS_LOG(INFO) << "Test edge weight sampling."; + NodeNeighborsMap number_neighbors; + int count = 0; + while (count < 1000) { + neighbors.reset(); + s = + graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>"); + ParsingNeighbors(neighbors, number_neighbors); + 
++count; + } + CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8}); + } neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>"); neighbors.reset(); s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, - {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); + {meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>"); neighbors.reset(); - s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos); 
neighbors.reset(); - s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors); + s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, + SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") != std::string::npos); neighbors.reset(); - s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors); + s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors); EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos); } diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata index ad97dbae98d..a52fb92282e 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata and b/tests/ut/data/mindrecord/testGraphData/testdata differ diff --git a/tests/ut/data/mindrecord/testGraphData/testdata.db b/tests/ut/data/mindrecord/testGraphData/testdata.db index 0f022589f4c..bade6acbc65 100644 Binary files a/tests/ut/data/mindrecord/testGraphData/testdata.db and b/tests/ut/data/mindrecord/testGraphData/testdata.db differ diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 053f232cfb1..83f84dc7b2b 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -17,6 +17,7 @@ import pytest import numpy as np import mindspore.dataset as ds from mindspore import log as logger +from mindspore.dataset.engine import SamplingStrategy DATASET_FILE = "../data/mindrecord/testGraphData/testdata" SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" @@ -97,7 +98,10 @@ def test_graphdata_getsampledneighbors(): nodes = g.get_nodes_from_edges(edges) assert len(nodes) == 40 neighbor = g.get_sampled_neighbors( - np.unique(nodes[0:21, 0]), [2, 3], [2, 1]) + 
np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM) + assert neighbor.shape == (10, 9) + neighbor = g.get_sampled_neighbors( + np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT) assert neighbor.shape == (10, 9) diff --git a/tests/ut/python/dataset/test_graphdata_distributed.py b/tests/ut/python/dataset/test_graphdata_distributed.py index 06bc8d460f0..97b4b8e137a 100644 --- a/tests/ut/python/dataset/test_graphdata_distributed.py +++ b/tests/ut/python/dataset/test_graphdata_distributed.py @@ -20,6 +20,7 @@ from multiprocessing import Process import numpy as np import mindspore.dataset as ds from mindspore import log as logger +from mindspore.dataset.engine import SamplingStrategy DATASET_FILE = "../data/mindrecord/testGraphData/testdata" @@ -68,9 +69,9 @@ class GNNGraphDataset(): neg_nodes = self.g.get_neg_sampled_neighbors( node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ - 2, 2], neighbor_types=[2, 1]) - neg_nodes_neighbors = self.g.get_sampled_neighbors( - node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) + 2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM) + neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], + neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT) nodes_neighbors_features = self.g.get_node_feature( node_list=nodes_neighbors, feature_types=[2, 3]) neg_neighbors_features = self.g.get_node_feature(