Optimize the negative sampling process

This commit is contained in:
heleiwang 2020-08-31 17:32:02 +08:00
parent 18253952f5
commit 857cf2f77f
3 changed files with 30 additions and 14 deletions

View File

@ -211,22 +211,22 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
return Status::OK();
}
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data,
const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
std::vector<NodeIdType> *out_samples) {
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data,
int32_t samples_num, std::vector<NodeIdType> *out_samples) {
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
std::vector<NodeIdType> shuffled_id(data.size());
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
for (const auto &index : shuffled_id) {
if (exclude_data.find(data[index]) != exclude_data.end()) {
size_t index = *start_index;
for (size_t i = index; i < shuffled_ids.size(); ++i) {
++index;
if (exclude_data.find(data[shuffled_ids[i]]) != exclude_data.end()) {
continue;
}
out_samples->emplace_back(data[index]);
out_samples->emplace_back(data[shuffled_ids[i]]);
if (out_samples->size() >= samples_num) {
break;
}
}
*start_index = index;
return Status::OK();
}
@ -236,6 +236,13 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
std::vector<NodeIdType> shuffled_id(all_nodes.size());
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
size_t start_index = 0;
bool need_shuffle = false;
std::vector<std::vector<NodeIdType>> neg_neighbors_vec;
neg_neighbors_vec.resize(node_list.size());
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) {
@ -247,12 +254,15 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
std::transform(neighbors.begin(), neighbors.end(),
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()),
[](const NodeIdType node) { return node; });
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type];
neg_neighbors_vec[node_idx].emplace_back(node->id());
if (all_nodes.size() > exclude_nodes.size()) {
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) {
RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(),
RETURN_IF_NOT_OK(NegativeSample(all_nodes, shuffled_id, &start_index, exclude_nodes, samples_num + 1,
&neg_neighbors_vec[node_idx]));
if (start_index >= shuffled_id.size()) {
start_index = start_index % shuffled_id.size();
need_shuffle = true;
}
}
} else {
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id()
@ -262,6 +272,11 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId);
}
}
if (need_shuffle) {
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
start_index = 0;
need_shuffle = false;
}
}
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out));
return Status::OK();

View File

@ -251,8 +251,9 @@ class GraphDataImpl : public GraphData {
// @param int32_t samples_num -
// @param std::vector<NodeIdType> *out_samples - Sampling results returned
// @return Status - The error code return
Status NegativeSample(const std::vector<NodeIdType> &input_data, const std::unordered_set<NodeIdType> &exclude_data,
int32_t samples_num, std::vector<NodeIdType> *out_samples);
Status NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
std::vector<NodeIdType> *out_samples);
Status CheckSamplesNum(NodeIdType samples_num);

View File

@ -87,7 +87,7 @@ def test_graphdata_distributed():
p1 = Process(target=graphdata_startserver, args=(server_port,))
p1.start()
time.sleep(2)
time.sleep(5)
g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
nodes = g.get_all_nodes(1)