# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import random
import time
from multiprocessing import Process

import numpy as np

import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
from mindspore.dataset.engine import OutputFormat

DATASET_FILE = "../data/mindrecord/testGraphData/testdata"


def graphdata_startserver(server_port):
    """
    Start a GraphData server on the given port.
    """
    logger.info('test start server.\n')
    ds.GraphData(DATASET_FILE, 1, 'server', port=server_port)


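# The test below runs graphdata_startserver in a separate Process and then connects a
# client to the same port with ds.GraphData(DATASET_FILE, 1, 'client', port=server_port),
# so server setup never stalls the test process itself.
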
class RandomBatchedSampler(ds.Sampler):
    # RandomBatchedSampler generates a random sequence without replacement in a batched manner
    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        indices = [i + 1 for i in range(self.index_range)]
        # Reset random seed here if necessary
        # random.seed(0)
        random.shuffle(indices)
        for i in range(0, self.index_range, self.num_edges_per_sample):
            # Drop remainder
            if i + self.num_edges_per_sample <= self.index_range:
                yield indices[i: i + self.num_edges_per_sample]


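# Illustration of RandomBatchedSampler with hypothetical values (not used by the test):
# iterating RandomBatchedSampler(index_range=6, num_edges_per_sample=2) yields three
# lists of two distinct, shuffled 1-based indices, e.g. [3, 1], [6, 2], [5, 4].
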
class GNNGraphDataset():
    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Total sample size of GNN dataset
        # In this case, the size should be total_num_edges/num_edges_per_sample
        return self.g.graph_info()['edge_num'][0] // self.batch_num

    def __getitem__(self, index):
        # index will be a list of indices yielded from RandomBatchedSampler
        # Fetch edges/nodes/samples/features based on indices
        nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
        nodes = nodes[:, 0]
        neg_nodes = self.g.get_neg_sampled_neighbors(
            node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        nodes_neighbors = self.g.get_sampled_neighbors(
            node_list=nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1],
            strategy=SamplingStrategy.RANDOM)
        neg_nodes_neighbors = self.g.get_sampled_neighbors(
            node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2],
            neighbor_types=[2, 1], strategy=SamplingStrategy.EDGE_WEIGHT)
        nodes_neighbors_features = self.g.get_node_feature(
            node_list=nodes_neighbors, feature_types=[2, 3])
        neg_neighbors_features = self.g.get_node_feature(
            node_list=neg_nodes_neighbors, feature_types=[2, 3])
        return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]


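# GNNGraphDataset.__getitem__ returns four arrays that GeneratorDataset maps, in order,
# to the output columns used below: "neighbors", "neg_neighbors", "neighbors_features"
# and "neg_neighbors_features". With batch_num=2 the sampler yields 20 two-edge batches
# per epoch here, matching the 40 iterations asserted after dataset.repeat(2).
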
def test_graphdata_distributed():
    """
    Feature: GraphData
    Description: Test GraphData distributed
    Expectation: Output is equal to the expected output
    """
    ASAN = os.environ.get('ASAN_OPTIONS')
    if ASAN:
        logger.info("skip the graphdata distributed test in asan mode")
        return

    logger.info('test distributed.\n')

    server_port = random.randint(10000, 60000)

    # Start the GraphData server in a separate process and give it time to come up.
    p1 = Process(target=graphdata_startserver, args=(server_port,))
    p1.start()
    time.sleep(5)

    # Connect a client and verify node and feature lookups.
    g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port)
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
    row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
                                      [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
                                      [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]

    # Fetch all neighbors in the three supported output formats.
    neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL)
    assert neighbor_normal.shape == (10, 6)
    neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
    assert neighbor_coo.shape == (20, 2)
    offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
    assert offset_table.shape == (10,)
    assert neighbor_csr.shape == (20,)

    edges = g.get_all_edges(0)
    assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                              21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
    features = g.get_edge_feature(edges, [1, 2])
    assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
                                    0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]

    nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (202, 102), (201, 107), (208, 108)]
    edges = g.get_edges_from_nodes(nodes_pair_list)
    assert edges.tolist() == [1, 9, 31, 17, 20, 25, 34, 37]

    # Build a GeneratorDataset over the graph, sampling batches of edges with RandomBatchedSampler.
    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
                                  sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)
    itr = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    i = 0
    for data in itr:
        assert data['neighbors'].shape == (2, 7)
        assert data['neg_neighbors'].shape == (6, 7)
        assert data['neighbors_features'].shape == (2, 7)
        assert data['neg_neighbors_features'].shape == (6, 7)
        i += 1
    assert i == 40


if __name__ == '__main__':
    test_graphdata_distributed()