!5624 Optimize the performance of GraphData.get_neg_sampled_neighbors

Merge pull request !5624 from heleiwang/gnn_perf
This commit is contained in:
mindspore-ci-bot 2020-09-01 19:22:53 +08:00 committed by Gitee
commit 8f3ebfd469
3 changed files with 30 additions and 14 deletions

View File

@ -211,22 +211,22 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
return Status::OK(); return Status::OK();
} }
// Collects negative samples by walking a shared, pre-shuffled list of indices into `data`.
//
// Scanning resumes at *start_index, so consecutive calls can keep consuming the same shuffled
// order instead of building a fresh one per call.
//
// @param const std::vector<NodeIdType> &data - Candidate node ids
// @param const std::vector<NodeIdType> shuffled_ids - Indices into `data`, visited in the order given
// @param size_t *start_index - In/out: position in shuffled_ids where scanning starts; on success it is
//     set to one past the last position that was inspected
// @param const std::unordered_set<NodeIdType> &exclude_data - Node ids that must not be sampled
// @param int32_t samples_num - Scanning stops once out_samples holds at least this many entries
// @param std::vector<NodeIdType> *out_samples - Accepted node ids are appended here
// @return Status - The error code return
//
// NOTE(review): shuffled_ids is taken by value, so the whole vector is copied on every call. Switching it
// to a const reference needs the matching change in the class declaration, so it is left untouched here.
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
                                     size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data,
                                     int32_t samples_num, std::vector<NodeIdType> *out_samples) {
  CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index != nullptr, "Input start_index is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(out_samples != nullptr, "Input out_samples is nullptr.");

  // Convert once, explicitly: comparing size() against a negative int32_t would otherwise promote it to a
  // huge unsigned value and the loop would never stop early.
  const size_t target_size = samples_num > 0 ? static_cast<size_t>(samples_num) : 0;

  size_t index = *start_index;
  while (index < shuffled_ids.size() && out_samples->size() < target_size) {
    // The position counts as consumed whether or not the candidate is accepted.
    const size_t data_pos = static_cast<size_t>(shuffled_ids[index]);
    ++index;
    // A negative id wraps to a huge value in the cast above, so one comparison covers both bounds.
    CHECK_FAIL_RETURN_UNEXPECTED(data_pos < data.size(), "Shuffled index is out of range.");
    const NodeIdType candidate = data[data_pos];
    if (exclude_data.find(candidate) != exclude_data.end()) {
      continue;
    }
    out_samples->emplace_back(candidate);
  }
  *start_index = index;
  return Status::OK();
}
@ -236,6 +236,13 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
std::vector<NodeIdType> shuffled_id(all_nodes.size());
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
size_t start_index = 0;
bool need_shuffle = false;
std::vector<std::vector<NodeIdType>> neg_neighbors_vec; std::vector<std::vector<NodeIdType>> neg_neighbors_vec;
neg_neighbors_vec.resize(node_list.size()); neg_neighbors_vec.resize(node_list.size());
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
@ -247,12 +254,15 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
std::transform(neighbors.begin(), neighbors.end(), std::transform(neighbors.begin(), neighbors.end(),
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()), std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()),
[](const NodeIdType node) { return node; }); [](const NodeIdType node) { return node; });
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
neg_neighbors_vec[node_idx].emplace_back(node->id()); neg_neighbors_vec[node_idx].emplace_back(node->id());
if (all_nodes.size() > exclude_nodes.size()) { if (all_nodes.size() > exclude_nodes.size()) {
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { while (neg_neighbors_vec[node_idx].size() < samples_num + 1) {
RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), RETURN_IF_NOT_OK(NegativeSample(all_nodes, shuffled_id, &start_index, exclude_nodes, samples_num + 1,
&neg_neighbors_vec[node_idx])); &neg_neighbors_vec[node_idx]));
if (start_index >= shuffled_id.size()) {
start_index = start_index % shuffled_id.size();
need_shuffle = true;
}
} }
} else { } else {
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
@ -262,6 +272,11 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
} }
} }
if (need_shuffle) {
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
start_index = 0;
need_shuffle = false;
}
} }
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out));
return Status::OK(); return Status::OK();

View File

@ -251,8 +251,9 @@ class GraphDataImpl : public GraphData {
// @param int32_t samples_num - // @param int32_t samples_num -
// @param std::vector<NodeIdType> *out_samples - Sampling results returned // @param std::vector<NodeIdType> *out_samples - Sampling results returned
// @return Status - The error code return // @return Status - The error code return
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data, Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
int32_t samples_num, std::vector<NodeIdType> *out_samples); size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
std::vector<NodeIdType> *out_samples);
Status CheckSamplesNum(NodeIdType samples_num); Status CheckSamplesNum(NodeIdType samples_num);

View File

@ -87,7 +87,7 @@ def test_graphdata_distributed():
p1 = Process(target=graphdata_startserver, args=(server_port,)) p1 = Process(target=graphdata_startserver, args=(server_port,))
p1.start() p1.start()
time.sleep(2) time.sleep(5)
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port) g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
nodes = g.get_all_nodes(1) nodes = g.get_all_nodes(1)