forked from mindspore-Ecosystem/mindspore
gnn support weight sampling neighbors.
This commit is contained in:
parent
3708624a25
commit
2dc9ba761c
|
@ -65,9 +65,9 @@ PYBIND_REGISTER(
|
|||
})
|
||||
.def("get_sampled_neighbors",
|
||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
||||
std::vector<gnn::NodeType> neighbor_types) {
|
||||
std::vector<gnn::NodeType> neighbor_types, SamplingStrategy strategy) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
|
||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_neg_sampled_neighbors",
|
||||
|
@ -114,8 +114,15 @@ PYBIND_REGISTER(
|
|||
return out;
|
||||
}))
|
||||
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
|
||||
.def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); });
|
||||
.def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); });
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<SamplingStrategy>(*m, "SamplingStrategy", py::arithmetic())
|
||||
.value("DE_SAMPLING_RANDOM", SamplingStrategy::kRandom)
|
||||
.value("DE_SAMPLING_EDGE_WEIGHT", SamplingStrategy::kEdgeWeight)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -71,6 +71,9 @@ enum class NormalizeForm {
|
|||
kNfkd,
|
||||
};
|
||||
|
||||
// Possible values for SamplingStrategy
|
||||
enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 };
|
||||
|
||||
// convenience functions for 32bit int bitmask
|
||||
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }
|
||||
|
||||
|
|
|
@ -35,10 +35,11 @@ class Edge {
|
|||
// Constructor
|
||||
// @param EdgeIdType id - edge id
|
||||
// @param EdgeType type - edge type
|
||||
// @param WeightType weight - edge weight
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
|
||||
: id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {}
|
||||
Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
|
||||
: id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {}
|
||||
|
||||
virtual ~Edge() = default;
|
||||
|
||||
|
@ -48,6 +49,9 @@ class Edge {
|
|||
// @return NodeIdType - Returned edge type
|
||||
EdgeType type() const { return type_; }
|
||||
|
||||
// @return WeightType - Returned edge weight
|
||||
WeightType weight() const { return weight_; }
|
||||
|
||||
// Get the feature of a edge
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
|
@ -77,6 +81,7 @@ class Edge {
|
|||
protected:
|
||||
EdgeIdType id_;
|
||||
EdgeType type_;
|
||||
WeightType weight_;
|
||||
std::shared_ptr<Node> src_node_;
|
||||
std::shared_ptr<Node> dst_node_;
|
||||
};
|
||||
|
|
|
@ -71,6 +71,7 @@ message GnnGraphDataRequestPb {
|
|||
repeated int32 number = 4; // samples number
|
||||
TensorPb id_tensor = 5; // input ids ,node id or edge id
|
||||
GnnRandomWalkPb random_walk = 6;
|
||||
int32 strategy = 7;
|
||||
}
|
||||
|
||||
message GnnGraphDataResponsePb {
|
||||
|
|
|
@ -76,11 +76,13 @@ class GraphData {
|
|||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::SamplingStrategy strategy - Sampling strategy
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status The status code returned
|
||||
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
|
||||
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get negative sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
|
@ -95,7 +97,7 @@ class GraphData {
|
|||
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||
// @param std::vector<NodeType> meta_path - node type of each step
|
||||
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - in out hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status The status code returned
|
||||
|
|
|
@ -137,7 +137,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list
|
|||
|
||||
Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
GnnGraphDataRequestPb request;
|
||||
GnnGraphDataResponsePb response;
|
||||
|
@ -151,6 +152,7 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_
|
|||
for (const auto &type : neighbor_types) {
|
||||
request.add_type(static_cast<google::protobuf::int32>(type));
|
||||
}
|
||||
request.set_strategy(static_cast<google::protobuf::int32>(strategy));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#endif
|
||||
return Status::OK();
|
||||
|
|
|
@ -86,10 +86,12 @@ class GraphDataClient : public GraphData {
|
|||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::SamplingStrategy strategy - Sampling strategy
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get negative sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
|
@ -104,7 +106,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||
// @param std::vector<NodeType> meta_path - node type of each step
|
||||
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - in out hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status The status code returned
|
||||
|
|
|
@ -171,7 +171,8 @@ Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) {
|
|||
|
||||
Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||
const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
|
||||
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
|
||||
|
@ -199,7 +200,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
|
|||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
|
||||
std::vector<NodeIdType> out;
|
||||
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out));
|
||||
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out));
|
||||
neighbors.insert(neighbors.end(), out.begin(), out.end());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -80,10 +80,12 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||
// @param std::SamplingStrategy strategy - Sampling strategy
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get negative sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
|
@ -98,7 +100,7 @@ class GraphDataImpl : public GraphData {
|
|||
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||
// @param std::vector<NodeType> meta_path - node type of each step
|
||||
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - in out hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status The status code returned
|
||||
|
@ -194,7 +196,7 @@ class GraphDataImpl : public GraphData {
|
|||
std::vector<NodeIdType> node_list_;
|
||||
std::vector<NodeType> meta_path_;
|
||||
float step_home_param_; // Return hyper parameter. Default is 1.0
|
||||
float step_away_param_; // Inout hyper parameter. Default is 1.0
|
||||
float step_away_param_; // In out hyper parameter. Default is 1.0
|
||||
NodeIdType default_node_;
|
||||
|
||||
int32_t num_walks_; // Number of walks per source. Default is 1
|
||||
|
|
|
@ -50,7 +50,7 @@ class GraphDataServer {
|
|||
|
||||
enum ServerState state() { return state_; }
|
||||
|
||||
bool IsStoped() {
|
||||
bool IsStopped() {
|
||||
if (state_ == kGdsStopped) {
|
||||
return true;
|
||||
} else {
|
||||
|
|
|
@ -78,7 +78,7 @@ grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context,
|
|||
}
|
||||
break;
|
||||
case GraphDataServer::kGdsStopped:
|
||||
response->set_error_msg("Stoped");
|
||||
response->set_error_msg("Stopped");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
|
@ -222,8 +222,9 @@ Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *re
|
|||
neighbor_types.resize(request->type().size());
|
||||
std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(),
|
||||
[](const google::protobuf::int32 type) { return static_cast<NodeType>(type); });
|
||||
SamplingStrategy strategy = static_cast<SamplingStrategy>(request->strategy());
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor));
|
||||
RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &tensor));
|
||||
TensorPb *result = response->add_result_data();
|
||||
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||
return Status::OK();
|
||||
|
|
|
@ -39,7 +39,9 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int
|
|||
row_id_(0),
|
||||
shard_reader_(nullptr),
|
||||
graph_feature_parser_(nullptr),
|
||||
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
|
||||
required_key_(
|
||||
{"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}),
|
||||
optional_key_({{"weight", false}}) {}
|
||||
|
||||
Status GraphLoader::GetNodesAndEdges() {
|
||||
NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
|
||||
|
@ -62,7 +64,7 @@ Status GraphLoader::GetNodesAndEdges() {
|
|||
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
|
||||
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight()));
|
||||
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
|
||||
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
|
||||
dq.pop_front();
|
||||
|
@ -95,12 +97,18 @@ Status GraphLoader::InitAndLoad() {
|
|||
|
||||
graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
|
||||
mindrecord::json schema = graph_impl_->data_schema_["schema"];
|
||||
for (const std::string &key : keys_) {
|
||||
for (const std::string &key : required_key_) {
|
||||
if (schema.find(key) == schema.end()) {
|
||||
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
|
||||
}
|
||||
}
|
||||
|
||||
for (auto op_key : optional_key_) {
|
||||
if (schema.find(op_key.first) != schema.end()) {
|
||||
optional_key_[op_key.first] = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (graph_impl_->server_mode_) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
int64_t total_blob_size = 0;
|
||||
|
@ -128,7 +136,11 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
|
|||
DefaultNodeFeatureMap *default_feature) {
|
||||
NodeIdType node_id = col_jsn["first_id"];
|
||||
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
|
||||
(*node) = std::make_shared<LocalNode>(node_id, node_type);
|
||||
WeightType weight = 1;
|
||||
if (optional_key_["weight"]) {
|
||||
weight = col_jsn["weight"];
|
||||
}
|
||||
(*node) = std::make_shared<LocalNode>(node_id, node_type, weight);
|
||||
std::vector<int32_t> indices;
|
||||
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices));
|
||||
if (graph_impl_->server_mode_) {
|
||||
|
@ -174,9 +186,13 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
|
|||
EdgeIdType edge_id = col_jsn["first_id"];
|
||||
EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
|
||||
NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
|
||||
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
|
||||
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
|
||||
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
|
||||
WeightType edge_weight = 1;
|
||||
if (optional_key_["weight"]) {
|
||||
edge_weight = col_jsn["weight"];
|
||||
}
|
||||
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1, 1);
|
||||
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1, 1);
|
||||
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, edge_weight, src, dst);
|
||||
std::vector<int32_t> indices;
|
||||
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
|
||||
if (graph_impl_->server_mode_) {
|
||||
|
|
|
@ -110,7 +110,8 @@ class GraphLoader {
|
|||
std::vector<EdgeFeatureMap> e_feature_maps_;
|
||||
std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
|
||||
std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
|
||||
const std::vector<std::string> keys_;
|
||||
const std::vector<std::string> required_key_;
|
||||
std::unordered_map<std::string, bool> optional_key_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
|
|
|
@ -21,8 +21,9 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
|
||||
: Edge(id, type, src_node, dst_node) {}
|
||||
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
|
||||
std::shared_ptr<Node> dst_node)
|
||||
: Edge(id, type, weight, src_node, dst_node) {}
|
||||
|
||||
Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
|
|
|
@ -34,9 +34,11 @@ class LocalEdge : public Edge {
|
|||
// Constructor
|
||||
// @param EdgeIdType id - edge id
|
||||
// @param EdgeType type - edge type
|
||||
// @param WeightType weight - edge weight
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node);
|
||||
LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
|
||||
std::shared_ptr<Node> dst_node);
|
||||
|
||||
~LocalEdge() = default;
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "minddata/dataset/engine/gnn/local_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
|
@ -26,7 +27,10 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); }
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight)
|
||||
: Node(id, type, weight), rnd_(GetRandomDevice()) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
|
@ -44,13 +48,13 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
|
|||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
if (exclude_itself) {
|
||||
neighbors.resize(itr->second.size());
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(),
|
||||
neighbors.resize(itr->second.first.size());
|
||||
std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(),
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
} else {
|
||||
neighbors.resize(itr->second.size() + 1);
|
||||
neighbors.resize(itr->second.first.size() + 1);
|
||||
neighbors[0] = id_;
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
|
||||
std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1,
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
}
|
||||
} else {
|
||||
|
@ -63,8 +67,8 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out) {
|
||||
Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out) {
|
||||
std::vector<NodeIdType> shuffled_id(neighbors.size());
|
||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
|
@ -75,14 +79,33 @@ Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
|
||||
const std::vector<WeightType> &weights, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(),
|
||||
"The number of neighbors does not match the weight.");
|
||||
std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end());
|
||||
for (int32_t i = 0; i < samples_num; ++i) {
|
||||
NodeIdType index = discrete_dist(rnd_);
|
||||
out->emplace_back(neighbors[index]->id());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
|
||||
std::vector<NodeIdType> *out_neighbors) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
neighbors.reserve(samples_num);
|
||||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
while (neighbors.size() < samples_num) {
|
||||
RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors));
|
||||
if (strategy == SamplingStrategy::kRandom) {
|
||||
while (neighbors.size() < samples_num) {
|
||||
RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors));
|
||||
}
|
||||
} else if (strategy == SamplingStrategy::kEdgeWeight) {
|
||||
RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid strategy");
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
|
||||
|
@ -95,12 +118,15 @@ Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_nu
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node) {
|
||||
Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) {
|
||||
auto itr = neighbor_nodes_.find(node->type());
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
itr->second.push_back(node);
|
||||
itr->second.first.push_back(node);
|
||||
itr->second.second.push_back(weight);
|
||||
} else {
|
||||
neighbor_nodes_[node->type()] = {node};
|
||||
std::vector<std::shared_ptr<Node>> nodes = {node};
|
||||
std::vector<WeightType> weights = {weight};
|
||||
neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/node.h"
|
||||
|
@ -33,7 +34,7 @@ class LocalNode : public Node {
|
|||
// Constructor
|
||||
// @param NodeIdType id - node id
|
||||
// @param NodeType type - node type
|
||||
LocalNode(NodeIdType id, NodeType type);
|
||||
LocalNode(NodeIdType id, NodeType type, WeightType weight);
|
||||
|
||||
~LocalNode() = default;
|
||||
|
||||
|
@ -53,15 +54,16 @@ class LocalNode : public Node {
|
|||
// Get the sampled neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired
|
||||
// @param SamplingStrategy strategy - Sampling strategy
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
|
||||
std::vector<NodeIdType> *out_neighbors) override;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @return Status The status code returned
|
||||
Status AddNeighbor(const std::shared_ptr<Node> &node) override;
|
||||
Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
|
@ -69,12 +71,16 @@ class LocalNode : public Node {
|
|||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out);
|
||||
Status GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out);
|
||||
|
||||
Status GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
|
||||
const std::vector<WeightType> &weights, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out);
|
||||
|
||||
std::mt19937 rnd_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_;
|
||||
std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
|
|
|
@ -28,6 +28,7 @@ namespace dataset {
|
|||
namespace gnn {
|
||||
using NodeType = int8_t;
|
||||
using NodeIdType = int32_t;
|
||||
using WeightType = float;
|
||||
|
||||
constexpr NodeIdType kDefaultNodeId = -1;
|
||||
|
||||
|
@ -36,7 +37,8 @@ class Node {
|
|||
// Constructor
|
||||
// @param NodeIdType id - node id
|
||||
// @param NodeType type - node type
|
||||
Node(NodeIdType id, NodeType type) : id_(id), type_(type) {}
|
||||
// @param WeightType type - node weight
|
||||
Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {}
|
||||
|
||||
virtual ~Node() = default;
|
||||
|
||||
|
@ -46,6 +48,9 @@ class Node {
|
|||
// @return NodeIdType - Returned node type
|
||||
NodeType type() const { return type_; }
|
||||
|
||||
// @return WeightType - Returned node weight
|
||||
WeightType weight() const { return weight_; }
|
||||
|
||||
// Get the feature of a node
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
|
@ -62,15 +67,16 @@ class Node {
|
|||
// Get the sampled neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired
|
||||
// @param SamplingStrategy strategy - Sampling strategy
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status The status code returned
|
||||
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
|
||||
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
|
||||
std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @return Status The status code returned
|
||||
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0;
|
||||
virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
|
@ -80,6 +86,7 @@ class Node {
|
|||
protected:
|
||||
NodeIdType id_;
|
||||
NodeType type_;
|
||||
WeightType weight_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
|
|
|
@ -71,6 +71,9 @@ enum class NormalizeForm {
|
|||
kNfkd,
|
||||
};
|
||||
|
||||
// Possible values for SamplingStrategy
|
||||
enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 };
|
||||
|
||||
// convenience functions for 32bit int bitmask
|
||||
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
|
|||
from ..core import config
|
||||
from .cache_client import DatasetCache
|
||||
from .datasets import *
|
||||
from .graphdata import GraphData
|
||||
from .graphdata import GraphData, SamplingStrategy
|
||||
from .iterators import *
|
||||
from .samplers import *
|
||||
from .serializer_deserializer import compare, deserialize, serialize, show
|
||||
|
|
|
@ -18,10 +18,12 @@ and provides operations related to graph data.
|
|||
"""
|
||||
import atexit
|
||||
import time
|
||||
from enum import IntEnum
|
||||
import numpy as np
|
||||
from mindspore._c_dataengine import GraphDataClient
|
||||
from mindspore._c_dataengine import GraphDataServer
|
||||
from mindspore._c_dataengine import Tensor
|
||||
from mindspore._c_dataengine import SamplingStrategy as Sampling
|
||||
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
|
||||
|
@ -29,6 +31,17 @@ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_
|
|||
check_gnn_random_walk
|
||||
|
||||
|
||||
class SamplingStrategy(IntEnum):
|
||||
RANDOM = 0
|
||||
EDGE_WEIGHT = 1
|
||||
|
||||
|
||||
DE_C_INTER_SAMPLING_STRATEGY = {
|
||||
SamplingStrategy.RANDOM: Sampling.DE_SAMPLING_RANDOM,
|
||||
SamplingStrategy.EDGE_WEIGHT: Sampling.DE_SAMPLING_EDGE_WEIGHT,
|
||||
}
|
||||
|
||||
|
||||
class GraphData:
|
||||
"""
|
||||
Reads the graph dataset used for GNN training from the shared file and database.
|
||||
|
@ -86,7 +99,7 @@ class GraphData:
|
|||
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
|
||||
atexit.register(stop)
|
||||
try:
|
||||
while self._graph_data.is_stoped() is not True:
|
||||
while self._graph_data.is_stopped() is not True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Graph data server receives KeyboardInterrupt.")
|
||||
|
@ -185,7 +198,7 @@ class GraphData:
|
|||
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
|
||||
|
||||
@check_gnn_get_sampled_neighbors
|
||||
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
|
||||
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM):
|
||||
"""
|
||||
Get sampled neighbor information.
|
||||
|
||||
|
@ -199,6 +212,11 @@ class GraphData:
|
|||
node_list (Union[list, numpy.ndarray]): The given list of nodes.
|
||||
neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop.
|
||||
neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.
|
||||
strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM).
|
||||
It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT].
|
||||
|
||||
- SamplingStrategy.RANDOM, random sampling with replacement.
|
||||
- SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, array of neighbors.
|
||||
|
@ -215,10 +233,12 @@ class GraphData:
|
|||
TypeError: If `neighbor_nums` is not list or ndarray.
|
||||
TypeError: If `neighbor_types` is not list or ndarray.
|
||||
"""
|
||||
if not isinstance(strategy, SamplingStrategy):
|
||||
raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.")
|
||||
if self._working_mode == 'server':
|
||||
raise Exception("This method is not supported when working mode is server.")
|
||||
return self._graph_data.get_sampled_neighbors(
|
||||
node_list, neighbor_nums, neighbor_types).as_array()
|
||||
node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array()
|
||||
|
||||
@check_gnn_get_neg_sampled_neighbors
|
||||
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
|
||||
|
@ -342,7 +362,7 @@ class GraphData:
|
|||
target_nodes (list[int]): Start node list in random walk
|
||||
meta_path (list[int]): node type for each walk step
|
||||
step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0).
|
||||
step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0).
|
||||
step_away_param (float, optional): in out hyper parameter in node2vec algorithm (Default = 1.0).
|
||||
default_node (int, optional): default node if no more neighbors found (Default = -1).
|
||||
A default value of -1 indicates that no node is given.
|
||||
|
||||
|
|
|
@ -1114,7 +1114,7 @@ def check_gnn_get_sampled_neighbors(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs)
|
||||
[node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
check_gnn_list_or_ndarray(node_list, 'node_list')
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ class GraphMapSchema:
|
|||
"second_id": {"type": "int64"},
|
||||
"third_id": {"type": "int64"},
|
||||
"type": {"type": "int32"},
|
||||
"weight": {"type": "float32"},
|
||||
"attribute": {"type": "string"}, # 'n' for ndoe, 'e' for edge
|
||||
"node_feature_index": {"type": "int32", "shape": [-1]},
|
||||
"edge_feature_index": {"type": "int32", "shape": [-1]}
|
||||
|
@ -91,8 +92,11 @@ class GraphMapSchema:
|
|||
logger.info("node cannot be None.")
|
||||
raise ValueError("node cannot be None.")
|
||||
|
||||
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
|
||||
"node_feature_index": []}
|
||||
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n',
|
||||
"type": node["type"], "node_feature_index": []}
|
||||
if "weight" in node:
|
||||
node_graph["weight"] = node["weight"]
|
||||
|
||||
for i in range(self.num_node_features):
|
||||
k = i + 1
|
||||
node_field_key = 'feature_' + str(k)
|
||||
|
@ -129,8 +133,11 @@ class GraphMapSchema:
|
|||
logger.info("edge cannot be None.")
|
||||
raise ValueError("edge cannot be None.")
|
||||
|
||||
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
|
||||
"type": edge["type"], "edge_feature_index": []}
|
||||
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0,
|
||||
"attribute": 'e', "type": edge["type"], "edge_feature_index": []}
|
||||
|
||||
if "weight" in edge:
|
||||
edge_graph["weight"] = edge["weight"]
|
||||
|
||||
for i in range(self.num_edge_features):
|
||||
k = i + 1
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
|
||||
|
@ -38,6 +39,60 @@ using namespace mindspore::dataset::gnn;
|
|||
class MindDataTestGNNGraph : public UT::Common {
|
||||
protected:
|
||||
MindDataTestGNNGraph() = default;
|
||||
|
||||
using NumNeighborsMap = std::map<NodeIdType, uint32_t>;
|
||||
using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>;
|
||||
void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) {
|
||||
auto shape_vec = neighbors->shape().AsVector();
|
||||
uint32_t num_members = 1;
|
||||
for (size_t i = 1; i < shape_vec.size(); ++i) {
|
||||
num_members *= shape_vec[i];
|
||||
}
|
||||
uint32_t index = 0;
|
||||
NodeIdType src_node = 0;
|
||||
for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>();
|
||||
++node_itr, ++index) {
|
||||
if (index % num_members == 0) {
|
||||
src_node = *node_itr;
|
||||
continue;
|
||||
}
|
||||
auto src_node_itr = node_neighbors.find(src_node);
|
||||
if (src_node_itr == node_neighbors.end()) {
|
||||
node_neighbors[src_node] = {{*node_itr, 1}};
|
||||
} else {
|
||||
auto nei_itr = src_node_itr->second.find(*node_itr);
|
||||
if (nei_itr == src_node_itr->second.end()) {
|
||||
src_node_itr->second[*node_itr] = 1;
|
||||
} else {
|
||||
src_node_itr->second[*node_itr] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights,
|
||||
float deviation_ratio = 0.1) {
|
||||
EXPECT_EQ(number_neighbors.size(), weights.size());
|
||||
int index = 0;
|
||||
uint32_t pre_num = 0;
|
||||
WeightType pre_weight = 1;
|
||||
for (auto neighbor : number_neighbors) {
|
||||
if (pre_num != 0) {
|
||||
float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]);
|
||||
float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second);
|
||||
float target_upper = target_ratio * (1 + deviation_ratio);
|
||||
float target_lower = target_ratio * (1 - deviation_ratio);
|
||||
MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio)
|
||||
<< " target_upper:" << std::to_string(target_upper)
|
||||
<< " target_lower:" << std::to_string(target_lower);
|
||||
EXPECT_LE(current_ratio, target_upper);
|
||||
EXPECT_GE(current_ratio, target_lower);
|
||||
}
|
||||
pre_num = neighbor.second;
|
||||
pre_weight = weights[index];
|
||||
++index;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
||||
|
@ -131,44 +186,75 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
|||
std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
|
||||
|
||||
std::shared_ptr<Tensor> neighbors;
|
||||
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
|
||||
{
|
||||
MS_LOG(INFO) << "Test random sampling.";
|
||||
NodeNeighborsMap number_neighbors;
|
||||
int count = 0;
|
||||
while (count < 1000) {
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
|
||||
ParsingNeighbors(neighbors, number_neighbors);
|
||||
++count;
|
||||
}
|
||||
CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1});
|
||||
}
|
||||
|
||||
{
|
||||
MS_LOG(INFO) << "Test edge weight sampling.";
|
||||
NodeNeighborsMap number_neighbors;
|
||||
int count = 0;
|
||||
while (count < 1000) {
|
||||
neighbors.reset();
|
||||
s =
|
||||
graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
|
||||
ParsingNeighbors(neighbors, number_neighbors);
|
||||
++count;
|
||||
}
|
||||
CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8});
|
||||
}
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]},
|
||||
SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
|
||||
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
|
||||
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]},
|
||||
SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]},
|
||||
SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors);
|
||||
s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
|
||||
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]},
|
||||
SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
|
||||
std::string::npos);
|
||||
|
||||
neighbors.reset();
|
||||
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors);
|
||||
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
|
||||
}
|
||||
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -17,6 +17,7 @@ import pytest
|
|||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.engine import SamplingStrategy
|
||||
|
||||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
|
||||
|
@ -97,7 +98,10 @@ def test_graphdata_getsampledneighbors():
|
|||
nodes = g.get_nodes_from_edges(edges)
|
||||
assert len(nodes) == 40
|
||||
neighbor = g.get_sampled_neighbors(
|
||||
np.unique(nodes[0:21, 0]), [2, 3], [2, 1])
|
||||
np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM)
|
||||
assert neighbor.shape == (10, 9)
|
||||
neighbor = g.get_sampled_neighbors(
|
||||
np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT)
|
||||
assert neighbor.shape == (10, 9)
|
||||
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ from multiprocessing import Process
|
|||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.engine import SamplingStrategy
|
||||
|
||||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
|
||||
|
@ -68,9 +69,9 @@ class GNNGraphDataset():
|
|||
neg_nodes = self.g.get_neg_sampled_neighbors(
|
||||
node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
|
||||
nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
|
||||
2, 2], neighbor_types=[2, 1])
|
||||
neg_nodes_neighbors = self.g.get_sampled_neighbors(
|
||||
node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
|
||||
2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM)
|
||||
neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
|
||||
neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
|
||||
nodes_neighbors_features = self.g.get_node_feature(
|
||||
node_list=nodes_neighbors, feature_types=[2, 3])
|
||||
neg_neighbors_features = self.g.get_node_feature(
|
||||
|
|
Loading…
Reference in New Issue