random walk v1
This commit is contained in:
parent
86253c342a
commit
87d2c27c7f
|
@ -0,0 +1,81 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
|
||||
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
|
||||
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
|
||||
[348, 347], [347, 351], [347, 327], [347, 329], [347, 331],
|
||||
[347, 335], [347, 341], [347, 345], [347, 346], [346, 335],
|
||||
[346, 340], [346, 339], [346, 349], [346, 353], [346, 354],
|
||||
[346, 341], [346, 345], [345, 335], [345, 336], [345, 341],
|
||||
[344, 338], [344, 342], [343, 332], [343, 338], [343, 342],
|
||||
[342, 332], [340, 349], [334, 349], [333, 349], [330, 349],
|
||||
[328, 349], [359, 349], [358, 352], [358, 349], [358, 354],
|
||||
[358, 356], [357, 350], [357, 354], [357, 356], [356, 350],
|
||||
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (0, [], [])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
node_list = []
|
||||
for edge in social_data:
|
||||
src, dst = edge
|
||||
if src not in node_list:
|
||||
node_list.append(src)
|
||||
if dst not in node_list:
|
||||
node_list.append(dst)
|
||||
node_list.sort()
|
||||
print(node_list)
|
||||
for node_id in node_list:
|
||||
node = {'id': node_id, 'type': 1}
|
||||
yield node
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
"""
|
||||
Generate edge data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
line_count = 0
|
||||
for undirected_edge in social_data:
|
||||
line_count += 1
|
||||
edge = {
|
||||
'id': line_count,
|
||||
'src_id': undirected_edge[0],
|
||||
'dst_id': undirected_edge[1],
|
||||
'type': 1}
|
||||
yield edge
|
||||
line_count += 1
|
||||
edge = {
|
||||
'id': line_count,
|
||||
'src_id': undirected_edge[1],
|
||||
'dst_id': undirected_edge[0],
|
||||
'type': 1}
|
||||
yield edge
|
|
@ -0,0 +1,10 @@
|
|||
#!/bin/bash
|
||||
MINDRECORD_PATH=/tmp/sns
|
||||
|
||||
rm -f $MINDRECORD_PATH/*
|
||||
|
||||
python writer.py --mindrecord_script sns \
|
||||
--mindrecord_file "$MINDRECORD_PATH/sns" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 14 \
|
||||
--mindrecord_page_size_by_bit 15
|
|
@ -584,9 +584,16 @@ void bindGraphData(py::module *m) {
|
|||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||
return out;
|
||||
})
|
||||
.def("graph_info", [](gnn::Graph &g) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(g.GraphInfo(&out));
|
||||
.def("graph_info",
|
||||
[](gnn::Graph &g) {
|
||||
py::dict out;
|
||||
THROW_IF_ERROR(g.GraphInfo(&out));
|
||||
return out;
|
||||
})
|
||||
.def("random_walk", [](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
|
||||
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
||||
return out;
|
||||
});
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ namespace dataset {
|
|||
namespace gnn {
|
||||
|
||||
Graph::Graph(std::string dataset_file, int32_t num_workers)
|
||||
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()) {
|
||||
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) {
|
||||
rnd_.seed(GetSeed());
|
||||
MS_LOG(INFO) << "num_workers:" << num_workers;
|
||||
}
|
||||
|
@ -240,8 +240,13 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p,
|
||||
float q, NodeIdType default_node, std::shared_ptr<Tensor> *out) {
|
||||
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
|
||||
std::vector<std::vector<NodeIdType>> walks;
|
||||
RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({walks}, DataType(DataType::DE_INT32), out));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -386,6 +391,195 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Graph::RandomWalkBase::RandomWalkBase(Graph *graph)
|
||||
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
|
||||
|
||||
Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, const NodeIdType default_node,
|
||||
int32_t num_walks, int32_t num_workers) {
|
||||
node_list_ = node_list;
|
||||
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
|
||||
std::string err_msg = "Failed, meta path required between 1 and " + std::to_string(kMaxNumWalks) +
|
||||
". The size of input path is " + std::to_string(meta_path.size());
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
meta_path_ = meta_path;
|
||||
if (step_home_param < kGnnEpsilon || step_away_param < kGnnEpsilon) {
|
||||
std::string err_msg = "Failed, step_home_param and step_away_param required greater than " +
|
||||
std::to_string(kGnnEpsilon) + ". step_home_param: " + std::to_string(step_home_param) +
|
||||
", step_away_param: " + std::to_string(step_away_param);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
step_home_param_ = step_home_param;
|
||||
step_away_param_ = step_away_param;
|
||||
default_node_ = default_node;
|
||||
num_walks_ = num_walks;
|
||||
num_workers_ = num_workers;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
|
||||
// Simulate a random walk starting from start node.
|
||||
auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector
|
||||
// walk simulate
|
||||
while (walk.size() - 1 < meta_path_.size()) {
|
||||
// current nodE
|
||||
auto cur_node_id = walk.back();
|
||||
std::shared_ptr<Node> cur_node;
|
||||
RETURN_IF_NOT_OK(graph_->GetNodeByNodeId(cur_node_id, &cur_node));
|
||||
|
||||
// current neighbors
|
||||
std::vector<NodeIdType> cur_neighbors;
|
||||
RETURN_IF_NOT_OK(cur_node->GetAllNeighbors(meta_path_[walk.size() - 1], &cur_neighbors, true));
|
||||
std::sort(cur_neighbors.begin(), cur_neighbors.end());
|
||||
|
||||
// break if no neighbors
|
||||
if (cur_neighbors.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// walk by the fist node, then by the previous 2 nodes
|
||||
std::shared_ptr<StochasticIndex> stochastic_index;
|
||||
if (walk.size() == 1) {
|
||||
RETURN_IF_NOT_OK(GetNodeProbability(cur_node_id, meta_path_[0], &stochastic_index));
|
||||
} else {
|
||||
NodeIdType prev_node_id = walk[walk.size() - 2];
|
||||
RETURN_IF_NOT_OK(GetEdgeProbability(prev_node_id, cur_node_id, walk.size() - 2, &stochastic_index));
|
||||
}
|
||||
NodeIdType next_node_id = cur_neighbors[WalkToNextNode(*stochastic_index)];
|
||||
walk.push_back(next_node_id);
|
||||
}
|
||||
|
||||
while (walk.size() - 1 < meta_path_.size()) {
|
||||
walk.push_back(default_node_);
|
||||
}
|
||||
|
||||
*walk_path = std::move(walk);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
|
||||
// Repeatedly simulate random walks from each node
|
||||
std::vector<uint32_t> permutation(node_list_.size());
|
||||
std::iota(permutation.begin(), permutation.end(), 0);
|
||||
for (int32_t i = 0; i < num_walks_; i++) {
|
||||
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
|
||||
std::shuffle(permutation.begin(), permutation.end(), std::default_random_engine(seed));
|
||||
for (const auto &i_perm : permutation) {
|
||||
std::vector<NodeIdType> walk;
|
||||
RETURN_IF_NOT_OK(Node2vecWalk(node_list_[i_perm], &walk));
|
||||
walks->push_back(walk);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
|
||||
std::shared_ptr<StochasticIndex> *node_probability) {
|
||||
// Generate alias nodes
|
||||
std::shared_ptr<Node> node;
|
||||
graph_->GetNodeByNodeId(node_id, &node);
|
||||
std::vector<NodeIdType> neighbors;
|
||||
RETURN_IF_NOT_OK(node->GetAllNeighbors(node_type, &neighbors, true));
|
||||
std::sort(neighbors.begin(), neighbors.end());
|
||||
auto non_normalized_probability = std::vector<float>(neighbors.size(), 1.0);
|
||||
*node_probability =
|
||||
std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
|
||||
std::shared_ptr<StochasticIndex> *edge_probability) {
|
||||
// Get the alias edge setup lists for a given edge.
|
||||
std::shared_ptr<Node> src_node;
|
||||
graph_->GetNodeByNodeId(src, &src_node);
|
||||
std::vector<NodeIdType> src_neighbors;
|
||||
RETURN_IF_NOT_OK(src_node->GetAllNeighbors(meta_path_[meta_path_index], &src_neighbors, true));
|
||||
|
||||
std::shared_ptr<Node> dst_node;
|
||||
graph_->GetNodeByNodeId(dst, &dst_node);
|
||||
std::vector<NodeIdType> dst_neighbors;
|
||||
RETURN_IF_NOT_OK(dst_node->GetAllNeighbors(meta_path_[meta_path_index + 1], &dst_neighbors, true));
|
||||
|
||||
std::sort(dst_neighbors.begin(), dst_neighbors.end());
|
||||
std::vector<float> non_normalized_probability;
|
||||
for (const auto &dst_nbr : dst_neighbors) {
|
||||
if (dst_nbr == src) {
|
||||
non_normalized_probability.push_back(1.0 / step_home_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
|
||||
continue;
|
||||
}
|
||||
auto it = std::find(src_neighbors.begin(), src_neighbors.end(), dst_nbr);
|
||||
if (it != src_neighbors.end()) {
|
||||
// stay close, this node connect both src and dst
|
||||
non_normalized_probability.push_back(1.0); // replace 1.0 with G[dst][dst_nbr]['weight']
|
||||
} else {
|
||||
// step far away
|
||||
non_normalized_probability.push_back(1.0 / step_away_param_); // replace 1.0 with G[dst][dst_nbr]['weight']
|
||||
}
|
||||
}
|
||||
|
||||
*edge_probability =
|
||||
std::make_shared<StochasticIndex>(GenerateProbability(Normalize<float>(non_normalized_probability)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
|
||||
uint32_t K = probability.size();
|
||||
std::vector<int32_t> switch_to_large_index(K, 0);
|
||||
std::vector<float> weight(K, .0);
|
||||
std::vector<int32_t> smaller;
|
||||
std::vector<int32_t> larger;
|
||||
auto random_device = GetRandomDevice();
|
||||
std::uniform_real_distribution<> distribution(-kGnnEpsilon, kGnnEpsilon);
|
||||
float accumulate_threshold = 0.0;
|
||||
for (uint32_t i = 0; i < K; i++) {
|
||||
float threshold_one = distribution(random_device);
|
||||
accumulate_threshold += threshold_one;
|
||||
weight[i] = i < K - 1 ? probability[i] * K + threshold_one : probability[i] * K - accumulate_threshold;
|
||||
weight[i] < 1.0 ? smaller.push_back(i) : larger.push_back(i);
|
||||
}
|
||||
|
||||
while ((!smaller.empty()) && (!larger.empty())) {
|
||||
uint32_t small = smaller.back();
|
||||
smaller.pop_back();
|
||||
uint32_t large = larger.back();
|
||||
larger.pop_back();
|
||||
switch_to_large_index[small] = large;
|
||||
weight[large] = weight[large] + weight[small] - 1.0;
|
||||
weight[large] < 1.0 ? smaller.push_back(large) : larger.push_back(large);
|
||||
}
|
||||
return StochasticIndex(switch_to_large_index, weight);
|
||||
}
|
||||
|
||||
uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
|
||||
auto switch_to_large_index = stochastic_index.first;
|
||||
auto weight = stochastic_index.second;
|
||||
const uint32_t size_of_index = switch_to_large_index.size();
|
||||
|
||||
auto random_device = GetRandomDevice();
|
||||
std::uniform_real_distribution<> distribution(0.0, 1.0);
|
||||
|
||||
// Generate random integer between [0, K)
|
||||
uint32_t random_idx = std::floor(distribution(random_device) * size_of_index);
|
||||
|
||||
if (distribution(random_device) < weight[random_idx]) {
|
||||
return random_idx;
|
||||
}
|
||||
return switch_to_large_index[random_idx];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
|
||||
float sum_probability =
|
||||
1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
|
||||
if (sum_probability < kGnnEpsilon) {
|
||||
sum_probability = 1.0;
|
||||
}
|
||||
std::vector<float> normalized_probability;
|
||||
std::transform(non_normalized_probability.begin(), non_normalized_probability.end(),
|
||||
std::back_inserter(normalized_probability), [&](T value) -> float { return value / sum_probability; });
|
||||
return normalized_probability;
|
||||
}
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,12 +16,14 @@
|
|||
#ifndef DATASET_ENGINE_GNN_GRAPH_H_
|
||||
#define DATASET_ENGINE_GNN_GRAPH_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/gnn/graph_loader.h"
|
||||
|
@ -34,6 +36,10 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
const float kGnnEpsilon = 0.0001;
|
||||
const uint32_t kMaxNumWalks = 80;
|
||||
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
|
||||
|
||||
struct MetaInfo {
|
||||
std::vector<NodeType> node_type;
|
||||
std::vector<EdgeType> edge_type;
|
||||
|
@ -98,8 +104,17 @@ class Graph {
|
|||
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
|
||||
NodeIdType default_node, std::shared_ptr<Tensor> *out);
|
||||
// Node2vec random walk.
|
||||
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||
// @param std::vector<NodeType> meta_path - node type of each step
|
||||
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||
// @param NodeIdType default_node - default node id
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||
// @return Status - The error code return
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||
std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
|
@ -130,6 +145,45 @@ class Graph {
|
|||
Status Init();
|
||||
|
||||
private:
|
||||
class RandomWalkBase {
|
||||
public:
|
||||
explicit RandomWalkBase(Graph *graph);
|
||||
|
||||
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
|
||||
int32_t num_walks = 1, int32_t num_workers = 1);
|
||||
|
||||
~RandomWalkBase() = default;
|
||||
|
||||
Status SimulateWalk(std::vector<std::vector<NodeIdType>> *walks);
|
||||
|
||||
private:
|
||||
Status Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path);
|
||||
|
||||
Status GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
|
||||
std::shared_ptr<StochasticIndex> *node_probability);
|
||||
|
||||
Status GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
|
||||
std::shared_ptr<StochasticIndex> *edge_probability);
|
||||
|
||||
static StochasticIndex GenerateProbability(const std::vector<float> &probability);
|
||||
|
||||
static uint32_t WalkToNextNode(const StochasticIndex &stochastic_index);
|
||||
|
||||
template <typename T>
|
||||
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
|
||||
|
||||
Graph *graph_;
|
||||
std::vector<NodeIdType> node_list_;
|
||||
std::vector<NodeType> meta_path_;
|
||||
float step_home_param_; // Return hyper parameter. Default is 1.0
|
||||
float step_away_param_; // Inout hyper parameter. Default is 1.0
|
||||
NodeIdType default_node_;
|
||||
|
||||
int32_t num_walks_; // Number of walks per source. Default is 10
|
||||
int32_t num_workers_; // The number of worker threads. Default is 1
|
||||
};
|
||||
|
||||
// Load graph data from mindrecord file
|
||||
// @return Status - The error code return
|
||||
Status LoadNodeAndEdge();
|
||||
|
@ -174,6 +228,7 @@ class Graph {
|
|||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
std::mt19937 rnd_;
|
||||
RandomWalkBase random_walk_;
|
||||
|
||||
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
|
||||
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
|
||||
|
|
|
@ -39,17 +39,25 @@ Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
|
|||
}
|
||||
}
|
||||
|
||||
Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) {
|
||||
Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors, bool exclude_itself) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
neighbors.resize(itr->second.size() + 1);
|
||||
neighbors[0] = id_;
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
if (exclude_itself) {
|
||||
neighbors.resize(itr->second.size());
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin(),
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
} else {
|
||||
neighbors.resize(itr->second.size() + 1);
|
||||
neighbors[0] = id_;
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
|
||||
[](const std::shared_ptr<Node> node) { return node->id(); });
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
|
||||
neighbors.emplace_back(id_);
|
||||
if (!exclude_itself) {
|
||||
neighbors.emplace_back(id_);
|
||||
}
|
||||
}
|
||||
*out_neighbors = std::move(neighbors);
|
||||
return Status::OK();
|
||||
|
|
|
@ -47,7 +47,8 @@ class LocalNode : public Node {
|
|||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) override;
|
||||
Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
|
||||
bool exclude_itself = false) override;
|
||||
|
||||
// Get the sampled neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
|
|
|
@ -56,7 +56,8 @@ class Node {
|
|||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
virtual Status GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType> *out_neighbors,
|
||||
bool exclude_itself = false) = 0;
|
||||
|
||||
// Get the sampled neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore._c_dataengine import Tensor
|
|||
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_all_neighbors, check_gnn_get_sampled_neighbors, \
|
||||
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature
|
||||
check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, check_gnn_random_walk
|
||||
|
||||
|
||||
class GraphData:
|
||||
|
@ -148,7 +148,8 @@ class GraphData:
|
|||
TypeError: If `neighbor_nums` is not list or ndarray.
|
||||
TypeError: If `neighbor_types` is not list or ndarray.
|
||||
"""
|
||||
return self._graph.get_sampled_neighbors(node_list, neighbor_nums, neighbor_types).as_array()
|
||||
return self._graph.get_sampled_neighbors(
|
||||
node_list, neighbor_nums, neighbor_types).as_array()
|
||||
|
||||
@check_gnn_get_neg_sampled_neighbors
|
||||
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
|
||||
|
@ -174,7 +175,8 @@ class GraphData:
|
|||
TypeError: If `neg_neighbor_num` is not integer.
|
||||
TypeError: If `neg_neighbor_type` is not integer.
|
||||
"""
|
||||
return self._graph.get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type).as_array()
|
||||
return self._graph.get_neg_sampled_neighbors(
|
||||
node_list, neg_neighbor_num, neg_neighbor_type).as_array()
|
||||
|
||||
@check_gnn_get_node_feature
|
||||
def get_node_feature(self, node_list, feature_types):
|
||||
|
@ -200,7 +202,10 @@ class GraphData:
|
|||
"""
|
||||
if isinstance(node_list, list):
|
||||
node_list = np.array(node_list, dtype=np.int32)
|
||||
return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
|
||||
return [
|
||||
t.as_array() for t in self._graph.get_node_feature(
|
||||
Tensor(node_list),
|
||||
feature_types)]
|
||||
|
||||
def graph_info(self):
|
||||
"""
|
||||
|
@ -212,3 +217,36 @@ class GraphData:
|
|||
node_feature_type and edge_feature_type.
|
||||
"""
|
||||
return self._graph.graph_info()
|
||||
|
||||
@check_gnn_random_walk
|
||||
def random_walk(
|
||||
self,
|
||||
target_nodes,
|
||||
meta_path,
|
||||
step_home_param=1.0,
|
||||
step_away_param=1.0,
|
||||
default_node=-1):
|
||||
"""
|
||||
Random walk in nodes.
|
||||
|
||||
Args:
|
||||
target_nodes (list[int]): Start node list in random walk
|
||||
meta_path (list[int]): node type for each walk step
|
||||
step_home_param (float): return hyper parameter in node2vec algorithm
|
||||
step_away_param (float): inout hyper parameter in node2vec algorithm
|
||||
default_node (int): default node if no more neighbors found
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.random_walk([1,2], [1,2,1,2,1])
|
||||
|
||||
Raises:
|
||||
TypeError: If `target_nodes` is not list or ndarray.
|
||||
TypeError: If `meta_path` is not list or ndarray.
|
||||
"""
|
||||
return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
|
||||
default_node).as_array()
|
||||
|
|
|
@ -1299,6 +1299,24 @@ def check_gnn_get_neg_sampled_neighbors(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_gnn_random_walk(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `random_walk` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
check_gnn_list_or_ndarray(param_dict.get("target_nodes"), 'target_nodes')
|
||||
|
||||
# check meta_path; required argument
|
||||
check_gnn_list_or_ndarray(param_dict.get("meta_path"), 'meta_path')
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_aligned_list(param, param_name, membor_type):
|
||||
"""Check whether the structure of each member of the list is the same."""
|
||||
|
||||
|
|
|
@ -27,6 +27,13 @@
|
|||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::gnn;
|
||||
|
||||
#define print_int_vec(_i, _str) \
|
||||
do { \
|
||||
std::stringstream ss; \
|
||||
std::copy(_i.begin(), _i.end(), std::ostream_iterator<int>(ss, " ")); \
|
||||
MS_LOG(INFO) << _str << " " << ss.str(); \
|
||||
} while (false)
|
||||
|
||||
class MindDataTestGNNGraph : public UT::Common {
|
||||
protected:
|
||||
MindDataTestGNNGraph() = default;
|
||||
|
@ -195,3 +202,29 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
|||
s = graph.GetNegSampledNeighbors(node_list, 3, 3, &neg_neighbors);
|
||||
EXPECT_TRUE(s.ToString().find("Invalid node type:3") != std::string::npos);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
||||
std::string path = "data/mindrecord/testGraphData/sns";
|
||||
Graph graph(path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
MetaInfo meta_info;
|
||||
s = graph.GetMetaInfo(&meta_info);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
std::shared_ptr<Tensor> nodes;
|
||||
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::vector<NodeIdType> node_list;
|
||||
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
|
||||
node_list.push_back(*itr);
|
||||
}
|
||||
|
||||
print_int_vec(node_list, "node list ");
|
||||
std::vector<NodeType> meta_path(59, 1);
|
||||
std::shared_ptr<Tensor> walk_path;
|
||||
s = graph.RandomWalk(node_list, meta_path, 2.0, 0.5, -1, &walk_path);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(walk_path->shape().ToString() == "<33,60>");
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -19,6 +19,7 @@ import mindspore.dataset as ds
|
|||
from mindspore import log as logger
|
||||
|
||||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
|
||||
|
||||
|
||||
def test_graphdata_getfullneighbor():
|
||||
|
@ -172,6 +173,17 @@ def test_graphdata_generatordataset():
|
|||
assert i == 40
|
||||
|
||||
|
||||
def test_graphdata_randomwalk():
|
||||
g = ds.GraphData(SOCIAL_DATA_FILE, 1)
|
||||
nodes = g.get_all_nodes(1)
|
||||
print(len(nodes))
|
||||
assert len(nodes) == 33
|
||||
|
||||
meta_path = [1 for _ in range(39)]
|
||||
walks = g.random_walk(nodes, meta_path)
|
||||
assert walks.shape == (33, 40)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_graphdata_getfullneighbor()
|
||||
logger.info('test_graphdata_getfullneighbor Ended.\n')
|
||||
|
@ -185,3 +197,5 @@ if __name__ == '__main__':
|
|||
logger.info('test_graphdata_graphinfo Ended.\n')
|
||||
test_graphdata_generatordataset()
|
||||
logger.info('test_graphdata_generatordataset Ended.\n')
|
||||
test_graphdata_randomwalk()
|
||||
logger.info('test_graphdata_randomwalk Ended.\n')
|
||||
|
|
Loading…
Reference in New Issue