forked from mindspore-Ecosystem/mindspore
!5624 Optimize the performance of GraphData.get_neg_sampled_neighbors
Merge pull request !5624 from heleiwang/gnn_perf
This commit is contained in:
commit
8f3ebfd469
|
@ -211,22 +211,22 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data,
|
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
|
||||||
const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
|
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data,
|
||||||
std::vector<NodeIdType> *out_samples) {
|
int32_t samples_num, std::vector<NodeIdType> *out_samples) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
|
||||||
std::vector<NodeIdType> shuffled_id(data.size());
|
size_t index = *start_index;
|
||||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
for (size_t i = index; i < shuffled_ids.size(); ++i) {
|
||||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
++index;
|
||||||
for (const auto &index : shuffled_id) {
|
if (exclude_data.find(data[shuffled_ids[i]]) != exclude_data.end()) {
|
||||||
if (exclude_data.find(data[index]) != exclude_data.end()) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
out_samples->emplace_back(data[index]);
|
out_samples->emplace_back(data[shuffled_ids[i]]);
|
||||||
if (out_samples->size() >= samples_num) {
|
if (out_samples->size() >= samples_num) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
*start_index = index;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -236,6 +236,13 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
|
||||||
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
|
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
|
||||||
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
|
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
|
||||||
|
|
||||||
|
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
|
||||||
|
std::vector<NodeIdType> shuffled_id(all_nodes.size());
|
||||||
|
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||||
|
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||||
|
size_t start_index = 0;
|
||||||
|
bool need_shuffle = false;
|
||||||
|
|
||||||
std::vector<std::vector<NodeIdType>> neg_neighbors_vec;
|
std::vector<std::vector<NodeIdType>> neg_neighbors_vec;
|
||||||
neg_neighbors_vec.resize(node_list.size());
|
neg_neighbors_vec.resize(node_list.size());
|
||||||
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
|
||||||
|
@ -247,12 +254,15 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
|
||||||
std::transform(neighbors.begin(), neighbors.end(),
|
std::transform(neighbors.begin(), neighbors.end(),
|
||||||
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()),
|
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()),
|
||||||
[](const NodeIdType node) { return node; });
|
[](const NodeIdType node) { return node; });
|
||||||
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
|
|
||||||
neg_neighbors_vec[node_idx].emplace_back(node->id());
|
neg_neighbors_vec[node_idx].emplace_back(node->id());
|
||||||
if (all_nodes.size() > exclude_nodes.size()) {
|
if (all_nodes.size() > exclude_nodes.size()) {
|
||||||
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) {
|
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) {
|
||||||
RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(),
|
RETURN_IF_NOT_OK(NegativeSample(all_nodes, shuffled_id, &start_index, exclude_nodes, samples_num + 1,
|
||||||
&neg_neighbors_vec[node_idx]));
|
&neg_neighbors_vec[node_idx]));
|
||||||
|
if (start_index >= shuffled_id.size()) {
|
||||||
|
start_index = start_index % shuffled_id.size();
|
||||||
|
need_shuffle = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
|
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
|
||||||
|
@ -262,6 +272,11 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
|
||||||
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
|
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (need_shuffle) {
|
||||||
|
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||||
|
start_index = 0;
|
||||||
|
need_shuffle = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out));
|
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|
|
@ -251,8 +251,9 @@ class GraphDataImpl : public GraphData {
|
||||||
// @param int32_t samples_num -
|
// @param int32_t samples_num -
|
||||||
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
|
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
|
Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
|
||||||
int32_t samples_num, std::vector<NodeIdType> *out_samples);
|
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
|
||||||
|
std::vector<NodeIdType> *out_samples);
|
||||||
|
|
||||||
Status CheckSamplesNum(NodeIdType samples_num);
|
Status CheckSamplesNum(NodeIdType samples_num);
|
||||||
|
|
||||||
|
|
|
@ -87,7 +87,7 @@ def test_graphdata_distributed():
|
||||||
|
|
||||||
p1 = Process(target=graphdata_startserver, args=(server_port,))
|
p1 = Process(target=graphdata_startserver, args=(server_port,))
|
||||||
p1.start()
|
p1.start()
|
||||||
time.sleep(2)
|
time.sleep(5)
|
||||||
|
|
||||||
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
|
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
|
||||||
nodes = g.get_all_nodes(1)
|
nodes = g.get_all_nodes(1)
|
||||||
|
|
Loading…
Reference in New Issue