gnn support weight sampling neighbors.

This commit is contained in:
heleiwang 2021-01-25 10:36:45 +08:00
parent 3708624a25
commit 2dc9ba761c
28 changed files with 282 additions and 76 deletions

View File

@ -65,9 +65,9 @@ PYBIND_REGISTER(
})
.def("get_sampled_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) {
std::vector<gnn::NodeType> neighbor_types, SamplingStrategy strategy) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &out));
return out;
})
.def("get_neg_sampled_neighbors",
@ -114,7 +114,14 @@ PYBIND_REGISTER(
return out;
}))
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
.def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); });
.def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); });
}));
// Python binding for the SamplingStrategy enum. The enumerators are registered under the names
// DE_SAMPLING_RANDOM and DE_SAMPLING_EDGE_WEIGHT. export_values() additionally exports those names into
// the enclosing module scope (pybind11 behaviour), so they can be referenced without the enum prefix.
PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) {
(void)py::enum_<SamplingStrategy>(*m, "SamplingStrategy", py::arithmetic())
.value("DE_SAMPLING_RANDOM", SamplingStrategy::kRandom)
.value("DE_SAMPLING_EDGE_WEIGHT", SamplingStrategy::kEdgeWeight)
.export_values();
}));
} // namespace dataset

View File

@ -71,6 +71,9 @@ enum class NormalizeForm {
kNfkd,
};
// Strategies that can be requested when sampling the neighbors of a graph node.
enum class SamplingStrategy {
  kRandom = 0,     // random sampling
  kEdgeWeight = 1  // sampling driven by edge weight
};
// Bitmask helper for 32-bit values.
// Returns true only when every bit that is set in bitMask is also set in bits
// (an empty mask therefore always matches).
inline bool BitTest(uint32_t bits, uint32_t bitMask) {
  const uint32_t matched = bits & bitMask;
  return matched == bitMask;
}

View File

@ -35,10 +35,11 @@ class Edge {
// Constructor
// @param EdgeIdType id - edge id
// @param EdgeType type - edge type
// @param WeightType weight - edge weight
// @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node
Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
: id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {}
Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
: id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {}
virtual ~Edge() = default;
@ -48,6 +49,9 @@ class Edge {
// @return EdgeType - Returned edge type
EdgeType type() const { return type_; }
// @return WeightType - Returned edge weight
WeightType weight() const { return weight_; }
// Get the feature of an edge
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
@ -77,6 +81,7 @@ class Edge {
protected:
EdgeIdType id_;
EdgeType type_;
WeightType weight_;
std::shared_ptr<Node> src_node_;
std::shared_ptr<Node> dst_node_;
};

View File

@ -71,6 +71,7 @@ message GnnGraphDataRequestPb {
repeated int32 number = 4; // samples number
TensorPb id_tensor = 5; // input ids ,node id or edge id
GnnRandomWalkPb random_walk = 6;
int32 strategy = 7;
}
message GnnGraphDataResponsePb {

View File

@ -76,11 +76,13 @@ class GraphData {
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param SamplingStrategy strategy - Sampling strategy
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status The status code returned
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) = 0;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
@ -95,7 +97,7 @@ class GraphData {
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param float step_away_param - in out hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status The status code returned

View File

@ -137,7 +137,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list
Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
GnnGraphDataRequestPb request;
GnnGraphDataResponsePb response;
@ -151,6 +152,7 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_
for (const auto &type : neighbor_types) {
request.add_type(static_cast<google::protobuf::int32>(type));
}
request.set_strategy(static_cast<google::protobuf::int32>(strategy));
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
#endif
return Status::OK();

View File

@ -86,10 +86,12 @@ class GraphDataClient : public GraphData {
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param SamplingStrategy strategy - Sampling strategy
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
@ -104,7 +106,7 @@ class GraphDataClient : public GraphData {
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param float step_away_param - in out hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status The status code returned

View File

@ -171,7 +171,8 @@ Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) {
Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
@ -199,7 +200,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
std::shared_ptr<Node> node;
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
std::vector<NodeIdType> out;
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], &out));
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out));
neighbors.insert(neighbors.end(), out.begin(), out.end());
}
}

View File

@ -80,10 +80,12 @@ class GraphDataImpl : public GraphData {
// @param std::vector<NodeType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param SamplingStrategy strategy - Sampling strategy
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status The status code returned
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
const std::vector<NodeType> &neighbor_types, SamplingStrategy strategy,
std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes
@ -98,7 +100,7 @@ class GraphDataImpl : public GraphData {
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param float step_away_param - in out hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status The status code returned
@ -194,7 +196,7 @@ class GraphDataImpl : public GraphData {
std::vector<NodeIdType> node_list_;
std::vector<NodeType> meta_path_;
float step_home_param_; // Return hyper parameter. Default is 1.0
float step_away_param_; // Inout hyper parameter. Default is 1.0
float step_away_param_; // In out hyper parameter. Default is 1.0
NodeIdType default_node_;
int32_t num_walks_; // Number of walks per source. Default is 1

View File

@ -50,7 +50,7 @@ class GraphDataServer {
enum ServerState state() { return state_; }
bool IsStoped() {
bool IsStopped() {
if (state_ == kGdsStopped) {
return true;
} else {

View File

@ -78,7 +78,7 @@ grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context,
}
break;
case GraphDataServer::kGdsStopped:
response->set_error_msg("Stoped");
response->set_error_msg("Stopped");
break;
}
} else {
@ -222,8 +222,9 @@ Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *re
neighbor_types.resize(request->type().size());
std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(),
[](const google::protobuf::int32 type) { return static_cast<NodeType>(type); });
SamplingStrategy strategy = static_cast<SamplingStrategy>(request->strategy());
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor));
RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, strategy, &tensor));
TensorPb *result = response->add_result_data();
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
return Status::OK();

View File

@ -39,7 +39,9 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int
row_id_(0),
shard_reader_(nullptr),
graph_feature_parser_(nullptr),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
required_key_(
{"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}),
optional_key_({{"weight", false}}) {}
Status GraphLoader::GetNodesAndEdges() {
NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
@ -62,7 +64,7 @@ Status GraphLoader::GetNodesAndEdges() {
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first));
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first));
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight()));
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
dq.pop_front();
@ -95,12 +97,18 @@ Status GraphLoader::InitAndLoad() {
graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
mindrecord::json schema = graph_impl_->data_schema_["schema"];
for (const std::string &key : keys_) {
for (const std::string &key : required_key_) {
if (schema.find(key) == schema.end()) {
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
}
}
for (auto op_key : optional_key_) {
if (schema.find(op_key.first) != schema.end()) {
optional_key_[op_key.first] = true;
}
}
if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64)
int64_t total_blob_size = 0;
@ -128,7 +136,11 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
DefaultNodeFeatureMap *default_feature) {
NodeIdType node_id = col_jsn["first_id"];
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
(*node) = std::make_shared<LocalNode>(node_id, node_type);
WeightType weight = 1;
if (optional_key_["weight"]) {
weight = col_jsn["weight"];
}
(*node) = std::make_shared<LocalNode>(node_id, node_type, weight);
std::vector<int32_t> indices;
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) {
@ -174,9 +186,13 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
EdgeIdType edge_id = col_jsn["first_id"];
EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
WeightType edge_weight = 1;
if (optional_key_["weight"]) {
edge_weight = col_jsn["weight"];
}
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1, 1);
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1, 1);
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, edge_weight, src, dst);
std::vector<int32_t> indices;
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) {

View File

@ -110,7 +110,8 @@ class GraphLoader {
std::vector<EdgeFeatureMap> e_feature_maps_;
std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
const std::vector<std::string> keys_;
const std::vector<std::string> required_key_;
std::unordered_map<std::string, bool> optional_key_;
};
} // namespace gnn
} // namespace dataset

View File

@ -21,8 +21,9 @@ namespace mindspore {
namespace dataset {
namespace gnn {
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
: Edge(id, type, src_node, dst_node) {}
// Constructs a local edge by forwarding every argument, including the edge weight, to the Edge base class.
// @param EdgeIdType id - edge id
// @param EdgeType type - edge type
// @param WeightType weight - edge weight
// @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
std::shared_ptr<Node> dst_node)
: Edge(id, type, weight, src_node, dst_node) {}
Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = features_.find(feature_type);

View File

@ -34,9 +34,11 @@ class LocalEdge : public Edge {
// Constructor
// @param EdgeIdType id - edge id
// @param EdgeType type - edge type
// @param WeightType weight - edge weight
// @param std::shared_ptr<Node> src_node - source node
// @param std::shared_ptr<Node> dst_node - destination node
LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node);
LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
std::shared_ptr<Node> dst_node);
~LocalEdge() = default;

View File

@ -16,6 +16,7 @@
#include "minddata/dataset/engine/gnn/local_node.h"
#include <algorithm>
#include <random>
#include <string>
#include <utility>
@ -26,7 +27,10 @@ namespace mindspore {
namespace dataset {
namespace gnn {
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type), rnd_(GetRandomDevice()) { rnd_.seed(GetSeed()); }
// Constructs a local node. id, type and weight are forwarded to the Node base class; the node's own
// random engine rnd_ is initialised from GetRandomDevice() and then re-seeded with GetSeed().
// @param NodeIdType id - node id
// @param NodeType type - node type
// @param WeightType weight - node weight
LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight)
: Node(id, type, weight), rnd_(GetRandomDevice()) {
rnd_.seed(GetSeed());
}
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = features_.find(feature_type);
@ -44,13 +48,13 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) {
if (exclude_itself) {
neighbors.resize(itr->second.size());
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(),
neighbors.resize(itr->second.first.size());
std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin(),
[](const std::shared_ptr<Node> node) { return node->id(); });
} else {
neighbors.resize(itr->second.size() + 1);
neighbors.resize(itr->second.first.size() + 1);
neighbors[0] = id_;
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
std::transform(itr->second.first.begin(), itr->second.first.end(), neighbors.begin() + 1,
[](const std::shared_ptr<Node> node) { return node->id(); });
}
} else {
@ -63,7 +67,7 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
return Status::OK();
}
Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
std::vector<NodeIdType> *out) {
std::vector<NodeIdType> shuffled_id(neighbors.size());
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
@ -75,14 +79,33 @@ Status LocalNode::GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &
return Status::OK();
}
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
// Sample neighbor ids with replacement, choosing index i with a probability proportional to weights[i]
// (std::discrete_distribution semantics).
// @param std::vector<std::shared_ptr<Node>> neighbors - Candidate neighbor nodes
// @param std::vector<WeightType> weights - Weight of each candidate, aligned by index with neighbors
// @param int32_t samples_num - Number of ids to draw; nothing is drawn when it is not positive
// @param std::vector<NodeIdType> *out - Sampled neighbor ids are appended to this vector
// @return Status The status code returned
Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
                                            const std::vector<WeightType> &weights, int32_t samples_num,
                                            std::vector<NodeIdType> *out) {
  CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(),
                               "The number of neighbors does not match the weight.");
  // A discrete_distribution built from an empty range always yields index 0, which would read past the
  // end of an empty neighbors vector, so reject that case explicitly.
  CHECK_FAIL_RETURN_UNEXPECTED(!neighbors.empty(), "There are no neighbors to sample from.");
  // NOTE(review): std::discrete_distribution also requires non-negative weights with a positive sum;
  // that is presumably guaranteed by the data loader - confirm.
  std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end());
  for (int32_t i = 0; i < samples_num; ++i) {
    const NodeIdType index = discrete_dist(rnd_);
    out->emplace_back(neighbors[index]->id());
  }
  return Status::OK();
}
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
std::vector<NodeIdType> *out_neighbors) {
std::vector<NodeIdType> neighbors;
neighbors.reserve(samples_num);
auto itr = neighbor_nodes_.find(neighbor_type);
if (itr != neighbor_nodes_.end()) {
if (strategy == SamplingStrategy::kRandom) {
while (neighbors.size() < samples_num) {
RETURN_IF_NOT_OK(GetSampledNeighbors(itr->second, samples_num - neighbors.size(), &neighbors));
RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors));
}
} else if (strategy == SamplingStrategy::kEdgeWeight) {
RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors));
} else {
RETURN_STATUS_UNEXPECTED("Invalid strategy");
}
} else {
MS_LOG(DEBUG) << "There are no neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
@ -95,12 +118,15 @@ Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_nu
return Status::OK();
}
Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node) {
Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) {
auto itr = neighbor_nodes_.find(node->type());
if (itr != neighbor_nodes_.end()) {
itr->second.push_back(node);
itr->second.first.push_back(node);
itr->second.second.push_back(weight);
} else {
neighbor_nodes_[node->type()] = {node};
std::vector<std::shared_ptr<Node>> nodes = {node};
std::vector<WeightType> weights = {weight};
neighbor_nodes_[node->type()] = std::make_pair(std::move(nodes), std::move(weights));
}
return Status::OK();
}

View File

@ -18,6 +18,7 @@
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/gnn/node.h"
@ -33,7 +34,7 @@ class LocalNode : public Node {
// Constructor
// @param NodeIdType id - node id
// @param NodeType type - node type
// @param WeightType weight - node weight
LocalNode(NodeIdType id, NodeType type);
LocalNode(NodeIdType id, NodeType type, WeightType weight);
~LocalNode() = default;
@ -53,15 +54,16 @@ class LocalNode : public Node {
// Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired
// @param SamplingStrategy strategy - Sampling strategy
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status The status code returned
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
std::vector<NodeIdType> *out_neighbors) override;
// Add neighbor of node
// @param std::shared_ptr<Node> node - neighbor node to add
// @param WeightType weight - weight stored alongside the neighbor
// @return Status The status code returned
Status AddNeighbor(const std::shared_ptr<Node> &node) override;
Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &) override;
// Update feature of node
// @param std::shared_ptr<Feature> feature -
@ -69,12 +71,16 @@ class LocalNode : public Node {
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
private:
Status GetSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
Status GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
std::vector<NodeIdType> *out);
Status GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
const std::vector<WeightType> &weights, int32_t samples_num,
std::vector<NodeIdType> *out);
std::mt19937 rnd_;
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_;
std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
};
} // namespace gnn
} // namespace dataset

View File

@ -28,6 +28,7 @@ namespace dataset {
namespace gnn {
using NodeType = int8_t;
using NodeIdType = int32_t;
using WeightType = float;
constexpr NodeIdType kDefaultNodeId = -1;
@ -36,7 +37,8 @@ class Node {
// Constructor
// @param NodeIdType id - node id
// @param NodeType type - node type
Node(NodeIdType id, NodeType type) : id_(id), type_(type) {}
// @param WeightType weight - node weight
Node(NodeIdType id, NodeType type, WeightType weight) : id_(id), type_(type), weight_(weight) {}
virtual ~Node() = default;
@ -46,6 +48,9 @@ class Node {
// @return NodeType - Returned node type
NodeType type() const { return type_; }
// @return WeightType - Returned node weight
WeightType weight() const { return weight_; }
// Get the feature of a node
// @param FeatureType feature_type - type of feature
// @param std::shared_ptr<Feature> *out_feature - Returned feature
@ -62,15 +67,16 @@ class Node {
// Get the sampled neighbors of a node
// @param NodeType neighbor_type - type of neighbor
// @param int32_t samples_num - Number of neighbors to be acquired
// @param SamplingStrategy strategy - Sampling strategy
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
// @return Status The status code returned
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num,
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
std::vector<NodeIdType> *out_neighbors) = 0;
// Add neighbor of node
// @param std::shared_ptr<Node> node - neighbor node to add
// @param WeightType weight - weight associated with the neighbor
// @return Status The status code returned
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0;
virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0;
// Update feature of node
// @param std::shared_ptr<Feature> feature -
@ -80,6 +86,7 @@ class Node {
protected:
NodeIdType id_;
NodeType type_;
WeightType weight_;
};
} // namespace gnn
} // namespace dataset

View File

@ -71,6 +71,9 @@ enum class NormalizeForm {
kNfkd,
};
// Strategies that can be requested when sampling the neighbors of a graph node.
enum class SamplingStrategy {
  kRandom = 0,     // random sampling
  kEdgeWeight = 1  // sampling driven by edge weight
};
// Bitmask helper for 32-bit values.
// Returns true only when every bit that is set in bitMask is also set in bits
// (an empty mask therefore always matches).
inline bool BitTest(uint32_t bits, uint32_t bitMask) {
  const uint32_t matched = bits & bitMask;
  return matched == bitMask;
}

View File

@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
from ..core import config
from .cache_client import DatasetCache
from .datasets import *
from .graphdata import GraphData
from .graphdata import GraphData, SamplingStrategy
from .iterators import *
from .samplers import *
from .serializer_deserializer import compare, deserialize, serialize, show

View File

@ -18,10 +18,12 @@ and provides operations related to graph data.
"""
import atexit
import time
from enum import IntEnum
import numpy as np
from mindspore._c_dataengine import GraphDataClient
from mindspore._c_dataengine import GraphDataServer
from mindspore._c_dataengine import Tensor
from mindspore._c_dataengine import SamplingStrategy as Sampling
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
@ -29,6 +31,17 @@ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_
check_gnn_random_walk
class SamplingStrategy(IntEnum):
    """Neighbor sampling strategies accepted by GraphData.get_sampled_neighbors."""

    # Random sampling.
    RANDOM = 0
    # Sampling that uses the edge weight as probability.
    EDGE_WEIGHT = 1
# Lookup table translating the Python-level SamplingStrategy members into the enum values exported by
# the C++ layer (imported in this module as `Sampling`).
DE_C_INTER_SAMPLING_STRATEGY = {
SamplingStrategy.RANDOM: Sampling.DE_SAMPLING_RANDOM,
SamplingStrategy.EDGE_WEIGHT: Sampling.DE_SAMPLING_EDGE_WEIGHT,
}
class GraphData:
"""
Reads the graph dataset used for GNN training from the shared file and database.
@ -86,7 +99,7 @@ class GraphData:
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
atexit.register(stop)
try:
while self._graph_data.is_stoped() is not True:
while self._graph_data.is_stopped() is not True:
time.sleep(1)
except KeyboardInterrupt:
raise Exception("Graph data server receives KeyboardInterrupt.")
@ -185,7 +198,7 @@ class GraphData:
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
@check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM):
"""
Get sampled neighbor information.
@ -199,6 +212,11 @@ class GraphData:
node_list (Union[list, numpy.ndarray]): The given list of nodes.
neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop.
neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop.
strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM).
It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT].
- SamplingStrategy.RANDOM, random sampling with replacement.
- SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability.
Returns:
numpy.ndarray, array of neighbors.
@ -215,10 +233,12 @@ class GraphData:
TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray.
"""
if not isinstance(strategy, SamplingStrategy):
raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.")
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server.")
return self._graph_data.get_sampled_neighbors(
node_list, neighbor_nums, neighbor_types).as_array()
node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array()
@check_gnn_get_neg_sampled_neighbors
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
@ -342,7 +362,7 @@ class GraphData:
target_nodes (list[int]): Start node list in random walk
meta_path (list[int]): node type for each walk step
step_home_param (float, optional): return hyper parameter in node2vec algorithm (Default = 1.0).
step_away_param (float, optional): inout hyper parameter in node2vec algorithm (Default = 1.0).
step_away_param (float, optional): in out hyper parameter in node2vec algorithm (Default = 1.0).
default_node (int, optional): default node if no more neighbors found (Default = -1).
A default value of -1 indicates that no node is given.

View File

@ -1114,7 +1114,7 @@ def check_gnn_get_sampled_neighbors(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[node_list, neighbor_nums, neighbor_types], _ = parse_user_args(method, *args, **kwargs)
[node_list, neighbor_nums, neighbor_types, _], _ = parse_user_args(method, *args, **kwargs)
check_gnn_list_or_ndarray(node_list, 'node_list')

View File

@ -37,6 +37,7 @@ class GraphMapSchema:
"second_id": {"type": "int64"},
"third_id": {"type": "int64"},
"type": {"type": "int32"},
"weight": {"type": "float32"},
"attribute": {"type": "string"}, # 'n' for ndoe, 'e' for edge
"node_feature_index": {"type": "int32", "shape": [-1]},
"edge_feature_index": {"type": "int32", "shape": [-1]}
@ -91,8 +92,11 @@ class GraphMapSchema:
logger.info("node cannot be None.")
raise ValueError("node cannot be None.")
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
"node_feature_index": []}
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "weight": 1.0, "attribute": 'n',
"type": node["type"], "node_feature_index": []}
if "weight" in node:
node_graph["weight"] = node["weight"]
for i in range(self.num_node_features):
k = i + 1
node_field_key = 'feature_' + str(k)
@ -129,8 +133,11 @@ class GraphMapSchema:
logger.info("edge cannot be None.")
raise ValueError("edge cannot be None.")
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
"type": edge["type"], "edge_feature_index": []}
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "weight": 1.0,
"attribute": 'e', "type": edge["type"], "edge_feature_index": []}
if "weight" in edge:
edge_graph["weight"] = edge["weight"]
for i in range(self.num_edge_features):
k = i + 1

View File

@ -15,6 +15,7 @@
*/
#include <algorithm>
#include <string>
#include <map>
#include <memory>
#include <unordered_set>
@ -38,6 +39,60 @@ using namespace mindspore::dataset::gnn;
// Test fixture for the GNN graph tests. Besides the default setup inherited from UT::Common it offers two
// helpers: one tallies how often each neighbor shows up in a sampled-neighbors tensor, the other checks that
// those tallies are in proportion to a list of expected weights.
class MindDataTestGNNGraph : public UT::Common {
 protected:
  MindDataTestGNNGraph() = default;

  // Neighbor id -> number of times that neighbor was seen.
  using NumNeighborsMap = std::map<NodeIdType, uint32_t>;
  // Source node id -> tally of the neighbors seen for that source node.
  using NodeNeighborsMap = std::map<NodeIdType, NumNeighborsMap>;

  // Walk a neighbors tensor and add its content to the running tallies in node_neighbors.
  // The tensor is read as consecutive groups of `num_members` elements, where `num_members` is the product of
  // every dimension after the first. The first element of a group is taken as the source node id, each
  // remaining element of the group is counted as one occurrence of a neighbor of that source node.
  // @param neighbors - tensor to parse, must not be null
  // @param node_neighbors - in/out tallies; existing counts are kept and incremented
  void ParsingNeighbors(const std::shared_ptr<Tensor> &neighbors, NodeNeighborsMap &node_neighbors) {
    auto shape_vec = neighbors->shape().AsVector();
    uint32_t num_members = 1;
    for (size_t i = 1; i < shape_vec.size(); ++i) {
      num_members *= shape_vec[i];
    }
    uint32_t index = 0;
    NodeIdType src_node = 0;
    for (auto node_itr = neighbors->begin<NodeIdType>(); node_itr != neighbors->end<NodeIdType>();
         ++node_itr, ++index) {
      if (index % num_members == 0) {
        // Start of a new group: remember the source node, it is not a neighbor itself.
        src_node = *node_itr;
        continue;
      }
      // std::map::operator[] value-initializes missing entries (empty inner map / zero counter), so a single
      // increment covers "new source node", "new neighbor" and "known neighbor" alike.
      node_neighbors[src_node][*node_itr] += 1;
    }
  }

  // Check that the occurrence counts of consecutive neighbors are in the same ratio as the corresponding
  // consecutive weights, within a relative tolerance.
  // Neighbors are visited in map order (ascending id) and paired with `weights` by position.
  // @param number_neighbors - tally of one source node's neighbors
  // @param weights - expected weight of every neighbor, one entry per entry of number_neighbors
  // @param deviation_ratio - allowed relative deviation of the measured ratio from the expected ratio
  void CheckNeighborsRatio(const NumNeighborsMap &number_neighbors, const std::vector<WeightType> &weights,
                           float deviation_ratio = 0.1) {
    EXPECT_EQ(number_neighbors.size(), weights.size());
    if (number_neighbors.size() != weights.size()) {
      // EXPECT_EQ is non-fatal: without this guard the loop below would index `weights` out of bounds
      // whenever more neighbors than weights were collected. The mismatch has already failed the test.
      return;
    }
    size_t index = 0;
    uint32_t pre_num = 0;  // 0 doubles as "no previous neighbor yet"; real tallies are always >= 1
    WeightType pre_weight = 1;
    for (const auto &neighbor : number_neighbors) {
      if (pre_num != 0) {
        float target_ratio = static_cast<float>(pre_weight) / static_cast<float>(weights[index]);
        float current_ratio = static_cast<float>(pre_num) / static_cast<float>(neighbor.second);
        float target_upper = target_ratio * (1 + deviation_ratio);
        float target_lower = target_ratio * (1 - deviation_ratio);
        MS_LOG(INFO) << "current_ratio:" << std::to_string(current_ratio)
                     << " target_upper:" << std::to_string(target_upper)
                     << " target_lower:" << std::to_string(target_lower);
        EXPECT_LE(current_ratio, target_upper);
        EXPECT_GE(current_ratio, target_lower);
      }
      pre_num = neighbor.second;
      pre_weight = weights[index];
      ++index;
    }
  }
};
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
@ -131,44 +186,75 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::transform(node_set.begin(), node_set.end(), node_list.begin(), [](const NodeIdType node) { return node; });
std::shared_ptr<Tensor> neighbors;
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, &neighbors);
{
MS_LOG(INFO) << "Test random sampling.";
NodeNeighborsMap number_neighbors;
int count = 0;
while (count < 1000) {
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
ParsingNeighbors(neighbors, number_neighbors);
++count;
}
CheckNeighborsRatio(number_neighbors[103], {1, 1, 1, 1, 1});
}
{
MS_LOG(INFO) << "Test edge weight sampling.";
NodeNeighborsMap number_neighbors;
int count = 0;
while (count < 1000) {
neighbors.reset();
s =
graph.GetSampledNeighbors(node_list, {10}, {meta_info.node_type[1]}, SamplingStrategy::kEdgeWeight, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,11>");
ParsingNeighbors(neighbors, number_neighbors);
++count;
}
CheckNeighborsRatio(number_neighbors[103], {3, 5, 6, 7, 8});
}
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
s = graph.GetSampledNeighbors(node_list, {2, 3}, {meta_info.node_type[1], meta_info.node_type[0]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,9>");
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4},
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
{meta_info.node_type[1], meta_info.node_type[0], meta_info.node_type[1]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<5,33>");
neighbors.reset();
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, &neighbors);
s = graph.GetSampledNeighbors({}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Input node_list is empty.") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, &neighbors);
s = graph.GetSampledNeighbors({-1, 1}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]}, &neighbors);
s = graph.GetSampledNeighbors(node_list, {2, 50}, {meta_info.node_type[0], meta_info.node_type[1]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Wrong samples number") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2}, {5}, &neighbors);
s = graph.GetSampledNeighbors(node_list, {2}, {5}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid neighbor type") != std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]}, &neighbors);
s = graph.GetSampledNeighbors(node_list, {2, 3, 4}, {meta_info.node_type[1], meta_info.node_type[0]},
SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("The sizes of neighbor_nums and neighbor_types are inconsistent.") !=
std::string::npos);
neighbors.reset();
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, &neighbors);
s = graph.GetSampledNeighbors({301}, {10}, {meta_info.node_type[1]}, SamplingStrategy::kRandom, &neighbors);
EXPECT_TRUE(s.ToString().find("Invalid node id:301") != std::string::npos);
}

View File

@ -17,6 +17,7 @@ import pytest
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
@ -97,7 +98,10 @@ def test_graphdata_getsampledneighbors():
nodes = g.get_nodes_from_edges(edges)
assert len(nodes) == 40
neighbor = g.get_sampled_neighbors(
np.unique(nodes[0:21, 0]), [2, 3], [2, 1])
np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.RANDOM)
assert neighbor.shape == (10, 9)
neighbor = g.get_sampled_neighbors(
np.unique(nodes[0:21, 0]), [2, 3], [2, 1], SamplingStrategy.EDGE_WEIGHT)
assert neighbor.shape == (10, 9)

View File

@ -20,6 +20,7 @@ from multiprocessing import Process
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
@ -68,9 +69,9 @@ class GNNGraphDataset():
neg_nodes = self.g.get_neg_sampled_neighbors(
node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
2, 2], neighbor_types=[2, 1])
neg_nodes_neighbors = self.g.get_sampled_neighbors(
node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
2, 2], neighbor_types=[2, 1], strategy=SamplingStrategy.RANDOM)
neg_nodes_neighbors = self.g.get_sampled_neighbors(node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
nodes_neighbors_features = self.g.get_node_feature(
node_list=nodes_neighbors, feature_types=[2, 3])
neg_neighbors_features = self.g.get_node_feature(