diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc index a37e92ed4ea..b97cad0ffaa 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc @@ -211,22 +211,22 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector &node_li return Status::OK(); } -Status GraphDataImpl::NegativeSample(const std::vector &data, - const std::unordered_set &exclude_data, int32_t samples_num, - std::vector *out_samples) { +Status GraphDataImpl::NegativeSample(const std::vector &data, const std::vector shuffled_ids, + size_t *start_index, const std::unordered_set &exclude_data, + int32_t samples_num, std::vector *out_samples) { CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); - std::vector shuffled_id(data.size()); - std::iota(shuffled_id.begin(), shuffled_id.end(), 0); - std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); - for (const auto &index : shuffled_id) { - if (exclude_data.find(data[index]) != exclude_data.end()) { + size_t index = *start_index; + for (size_t i = index; i < shuffled_ids.size(); ++i) { + ++index; + if (exclude_data.find(data[shuffled_ids[i]]) != exclude_data.end()) { continue; } - out_samples->emplace_back(data[index]); + out_samples->emplace_back(data[shuffled_ids[i]]); if (out_samples->size() >= samples_num) { break; } } + *start_index = index; return Status::OK(); } @@ -236,6 +236,13 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector &node RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); + const std::vector &all_nodes = node_type_map_[neg_neighbor_type]; + std::vector shuffled_id(all_nodes.size()); + std::iota(shuffled_id.begin(), shuffled_id.end(), 0); + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + size_t start_index = 0; + bool need_shuffle = false; + std::vector> neg_neighbors_vec; neg_neighbors_vec.resize(node_list.size()); for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { @@ -247,12 +254,15 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector &node std::transform(neighbors.begin(), neighbors.end(), std::insert_iterator>(exclude_nodes, exclude_nodes.begin()), [](const NodeIdType node) { return node; }); - const std::vector &all_nodes = node_type_map_[neg_neighbor_type]; neg_neighbors_vec[node_idx].emplace_back(node->id()); if (all_nodes.size() > exclude_nodes.size()) { while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { - RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), + RETURN_IF_NOT_OK(NegativeSample(all_nodes, shuffled_id, &start_index, exclude_nodes, samples_num + 1, &neg_neighbors_vec[node_idx])); + if (start_index >= shuffled_id.size()) { + start_index = start_index % shuffled_id.size(); + need_shuffle = true; + } } } else { MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() @@ -262,6 +272,11 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector &node neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); } } + if (need_shuffle) { + std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); + start_index = 0; + need_shuffle = false; + } } RETURN_IF_NOT_OK(CreateTensorByVector(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h index d596e99a2d9..5a6853cf8ba 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h @@ -251,8 +251,9 @@ class GraphDataImpl : public GraphData { // @param int32_t samples_num - // @param std::vector *out_samples - Sampling results returned // @return Status - The error code return - Status NegativeSample(const std::vector &input_data, const std::unordered_set &exclude_data, - int32_t samples_num, std::vector *out_samples); + Status NegativeSample(const std::vector &data, const std::vector shuffled_ids, + size_t *start_index, const std::unordered_set &exclude_data, int32_t samples_num, + std::vector *out_samples); Status CheckSamplesNum(NodeIdType samples_num); diff --git a/tests/ut/python/dataset/test_graphdata_distributed.py b/tests/ut/python/dataset/test_graphdata_distributed.py index 9d70cb57b6b..9762b3e8f77 100644 --- a/tests/ut/python/dataset/test_graphdata_distributed.py +++ b/tests/ut/python/dataset/test_graphdata_distributed.py @@ -87,7 +87,7 @@ def test_graphdata_distributed(): p1 = Process(target=graphdata_startserver, args=(server_port,)) p1.start() - time.sleep(2) + time.sleep(5) g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port) nodes = g.get_all_nodes(1)