forked from mindspore-Ecosystem/mindspore
!5624 Optimize the performance of GraphData.get_neg_sampled_neighbors
Merge pull request !5624 from heleiwang/gnn_perf
This commit is contained in:
commit
8f3ebfd469
|
@ -211,22 +211,22 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data,
|
||||
const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_samples) {
|
||||
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
|
||||
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data,
|
||||
int32_t samples_num, std::vector<NodeIdType> *out_samples) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
|
||||
std::vector<NodeIdType> shuffled_id(data.size());
|
||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
for (const auto &index : shuffled_id) {
|
||||
if (exclude_data.find(data[index]) != exclude_data.end()) {
|
||||
size_t index = *start_index;
|
||||
for (size_t i = index; i < shuffled_ids.size(); ++i) {
|
||||
++index;
|
||||
if (exclude_data.find(data[shuffled_ids[i]]) != exclude_data.end()) {
|
||||
continue;
|
||||
}
|
||||
out_samples->emplace_back(data[index]);
|
||||
out_samples->emplace_back(data[shuffled_ids[i]]);
|
||||
if (out_samples->size() >= samples_num) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
*start_index = index;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -236,6 +236,13 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
|
|||
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
|
||||
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
|
||||
|
||||
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
|
||||
std::vector<NodeIdType> shuffled_id(all_nodes.size());
|
||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
size_t start_index = 0;
|
||||
bool need_shuffle = false;
|
||||
|
||||
std::vector<std::vector<NodeIdType>> neg_neighbors_vec;
|
||||
neg_neighbors_vec.resize(node_list.size());
|
||||
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
||||
|
@ -247,12 +254,15 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
|
|||
std::transform(neighbors.begin(), neighbors.end(),
|
||||
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()),
|
||||
[](const NodeIdType node) { return node; });
|
||||
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
|
||||
neg_neighbors_vec[node_idx].emplace_back(node->id());
|
||||
if (all_nodes.size() > exclude_nodes.size()) {
|
||||
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) {
|
||||
RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(),
|
||||
RETURN_IF_NOT_OK(NegativeSample(all_nodes, shuffled_id, &start_index, exclude_nodes, samples_num + 1,
|
||||
&neg_neighbors_vec[node_idx]));
|
||||
if (start_index >= shuffled_id.size()) {
|
||||
start_index = start_index % shuffled_id.size();
|
||||
need_shuffle = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
|
||||
|
@ -262,6 +272,11 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
|
|||
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
|
||||
}
|
||||
}
|
||||
if (need_shuffle) {
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
start_index = 0;
|
||||
need_shuffle = false;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out));
|
||||
return Status::OK();
|
||||
|
|
|
@ -251,8 +251,9 @@ class GraphDataImpl : public GraphData {
|
|||
// @param int32_t samples_num -
|
||||
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
|
||||
// @return Status - The error code return
|
||||
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
|
||||
int32_t samples_num, std::vector<NodeIdType> *out_samples);
|
||||
Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
|
||||
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out_samples);
|
||||
|
||||
Status CheckSamplesNum(NodeIdType samples_num);
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ def test_graphdata_distributed():
|
|||
|
||||
p1 = Process(target=graphdata_startserver, args=(server_port,))
|
||||
p1.start()
|
||||
time.sleep(2)
|
||||
time.sleep(5)
|
||||
|
||||
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
|
||||
nodes = g.get_all_nodes(1)
|
||||
|
|
Loading…
Reference in New Issue