forked from mindspore-Ecosystem/mindspore
Support processing GNN data
This commit is contained in:
parent
d6e930d737
commit
599a449e0b
|
@ -0,0 +1,73 @@
# Guideline to Efficiently Generating MindRecord

<!-- TOC -->

- [What does the example do](#what-does-the-example-do)
- [Example test for Cora](#example-test-for-cora)
- [How to use the example for other dataset](#how-to-use-the-example-for-other-dataset)
    - [Create work space](#create-work-space)
    - [Implement data generator](#implement-data-generator)
    - [Run data generator](#run-data-generator)

<!-- /TOC -->

## What does the example do

This example provides an efficient way to generate MindRecord. Users only need to define the parallel granularity for reading their training data and the read function for a single task; the tool then converts the training data into MindRecord efficiently.

1. write_cora.sh: entry script; users modify its parameters to match their own training data.
2. writer.py: main script, called by write_cora.sh; it reads the user's training data in parallel and generates MindRecord.
3. cora/mr_api.py: where users define the parallel granularity of data reading and the per-task read function; the Cora implementation serves as the template.

## Example test for Cora

1. Download and prepare the Cora dataset as required.

> [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora)

2. Edit write_cora.sh and modify the parameters:

```
--mindrecord_file: output MindRecord file.
--mindrecord_partitions: number of partitions for the MindRecord file.
```

3. Run the bash script:

```bash
bash write_cora.sh
```

## How to use the example for other dataset

### Create work space

Assume the dataset name is 'xyz'.
* Create the work space from cora

```shell
cd ${your_mindspore_home}/example/graph_to_mindrecord
cp -r cora xyz
```

### Implement data generator

Edit the dictionary data generator.
* Edit file

```shell
cd ${your_mindspore_home}/example/graph_to_mindrecord
vi xyz/mr_api.py
```

Two APIs, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented (a skeleton is sketched below).
- 'mindrecord_task_number()' returns the number of tasks. Return 1 if the data rows are generated serially; return N if the generator can be split into N parallel tasks.
- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' ranges from 0 to N-1, where N is the return value of mindrecord_task_number().
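
The following is a minimal sketch of what xyz/mr_api.py could look like. Note that the graph writer shipped with this example (writer.py) drives the module through node_profile/edge_profile and the yield_nodes/yield_edges generators, as cora/mr_api.py does; the CSV layout, field names, and feature profile below are illustrative assumptions, not part of the example.

```python
# Sketch of xyz/mr_api.py -- illustrative only; adapt the field names, profiles and
# the assumed CSV layout (node file: "id,feat,...", edge file: "src,dst") to your data.
import csv
import os

# writer.py passes the data file paths through the 'graph_api_args' environment variable
args = os.environ['graph_api_args'].split(':')
NODE_FILE, EDGE_FILE = args[0], args[1]

# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (1, ["float32"], [[-1]])
edge_profile = (0, [], [])


def yield_nodes(task_id=0):
    """Yield one node dict per row of the node CSV."""
    with open(NODE_FILE) as f:
        for i, row in enumerate(csv.reader(f)):
            yield {'id': i, 'type': 0, 'feature_1': [float(x) for x in row[1:]]}


def yield_edges(task_id=0):
    """Yield one edge dict per row of the edge CSV."""
    with open(EDGE_FILE) as f:
        for i, row in enumerate(csv.reader(f)):
            yield {'id': i, 'src_id': int(row[0]), 'dst_id': int(row[1]), 'type': 0}
```

writer.py runs yield_nodes for --num_node_tasks tasks and yield_edges for --num_edge_tasks tasks; with the default of one task each, task_id can be ignored here.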

### Run data generator

* Run the python script

```shell
cd ${your_mindspore_home}/example/graph_to_mindrecord
python writer.py --mindrecord_script xyz [...]
```

> You can put this command in a script **write_xyz.sh** for easy execution.

@ -0,0 +1,109 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import csv
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
CITESEER_CONTENT_FILE = args[0]
|
||||
CITESEER_CITES_FILE = args[1]
|
||||
CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord"
|
||||
CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord"
|
||||
|
||||
node_id_map = {}
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
|
||||
edge_profile = (0, [], [])
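# Note on the profiles above: each profile is (num_features, feature_data_types,
# feature_shapes); a shape of [-1] appears to denote a variable-length (flattened)
# feature, and edge_profile is empty because edges in this dataset carry no features.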
|
||||
|
||||
|
||||
def _normalize_citeseer_features(features):
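# Row-normalize the feature matrix: each row is divided by its row sum so that
# its entries sum to 1; rows whose sum is zero are left unchanged (their inverse
# is set to 0 instead of inf).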
|
||||
features = np.array(features)
|
||||
row_sum = np.array(features.sum(1))
|
||||
r_inv = np.power(row_sum * 1.0, -1).flatten()
|
||||
r_inv[np.isinf(r_inv)] = 0.
|
||||
r_mat_inv = sp.diags(r_inv)
|
||||
features = r_mat_inv.dot(features)
|
||||
return features
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
label_types = {}
|
||||
label_size = 0
|
||||
node_num = 0
|
||||
with open(CITESEER_CONTENT_FILE) as content_file:
|
||||
content_reader = csv.reader(content_file, delimiter='\t')
|
||||
line_count = 0
|
||||
for row in content_reader:
|
||||
if not row[-1] in label_types:
|
||||
label_types[row[-1]] = label_size
|
||||
label_size += 1
|
||||
if not row[0] in node_id_map:
|
||||
node_id_map[row[0]] = node_num
|
||||
node_num += 1
|
||||
raw_features = [[int(x) for x in row[1:-1]]]
|
||||
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features),
|
||||
'feature_2': [label_types[row[-1]]]}
|
||||
yield node
|
||||
line_count += 1
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
# print('label types {}.'.format(label_types))
|
||||
with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f:
|
||||
for k in label_types:
|
||||
print(k + ',' + str(label_types[k]), file=f)
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
"""
|
||||
Generate edge data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
# print(map_string_int)
|
||||
with open(CITESEER_CITES_FILE) as cites_file:
|
||||
cites_reader = csv.reader(cites_file, delimiter='\t')
|
||||
line_count = 0
|
||||
for row in cites_reader:
|
||||
if not row[0] in node_id_map:
|
||||
print('Source node {} does not exist.'.format(row[0]))
|
||||
continue
|
||||
if not row[1] in node_id_map:
|
||||
print('Destination node {} does not exist.'.format(row[1]))
|
||||
continue
|
||||
line_count += 1
|
||||
edge = {'id': line_count,
|
||||
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
|
||||
yield edge
|
||||
|
||||
with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f:
|
||||
for k in node_id_map:
|
||||
print(k + ',' + str(node_id_map[k]), file=f)
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
User-defined API for MindRecord GNN writer.
|
||||
"""
|
||||
import csv
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
# parse args from command line parameter 'graph_api_args'
|
||||
# args delimiter is ':'
|
||||
args = os.environ['graph_api_args'].split(':')
|
||||
CORA_CONTENT_FILE = args[0]
|
||||
CORA_CITES_FILE = args[1]
|
||||
CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord"
|
||||
CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord"
|
||||
|
||||
node_id_map = {}
|
||||
|
||||
# profile: (num_features, feature_data_types, feature_shapes)
|
||||
node_profile = (2, ["float32", "int64"], [[-1], [-1]])
|
||||
edge_profile = (0, [], [])
|
||||
|
||||
|
||||
def _normalize_cora_features(features):
|
||||
features = np.array(features)
|
||||
row_sum = np.array(features.sum(1))
|
||||
r_inv = np.power(row_sum * 1.0, -1).flatten()
|
||||
r_inv[np.isinf(r_inv)] = 0.
|
||||
r_mat_inv = sp.diags(r_inv)
|
||||
features = r_mat_inv.dot(features)
|
||||
return features
|
||||
|
||||
|
||||
def yield_nodes(task_id=0):
|
||||
"""
|
||||
Generate node data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Node task is {}".format(task_id))
|
||||
label_types = {}
|
||||
label_size = 0
|
||||
node_num = 0
|
||||
with open(CORA_CONTENT_FILE) as content_file:
|
||||
content_reader = csv.reader(content_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in content_reader:
|
||||
if line_count == 0:
|
||||
line_count += 1
|
||||
continue
|
||||
if not row[0] in node_id_map:
|
||||
node_id_map[row[0]] = node_num
|
||||
node_num += 1
|
||||
if not row[-1] in label_types:
|
||||
label_types[row[-1]] = label_size
|
||||
label_size += 1
|
||||
raw_features = [[int(x) for x in row[1:-1]]]
|
||||
node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features),
|
||||
'feature_2': [label_types[row[-1]]]}
|
||||
yield node
|
||||
line_count += 1
|
||||
print('Processed {} lines for nodes.'.format(line_count))
|
||||
print('label types {}.'.format(label_types))
|
||||
with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f:
|
||||
for k in label_types:
|
||||
print(k + ',' + str(label_types[k]), file=f)
|
||||
|
||||
|
||||
def yield_edges(task_id=0):
|
||||
"""
|
||||
Generate edge data
|
||||
|
||||
Yields:
|
||||
data (dict): data row which is dict.
|
||||
"""
|
||||
print("Edge task is {}".format(task_id))
|
||||
with open(CORA_CITES_FILE) as cites_file:
|
||||
cites_reader = csv.reader(cites_file, delimiter=',')
|
||||
line_count = 0
|
||||
for row in cites_reader:
|
||||
if line_count == 0:
|
||||
line_count += 1
|
||||
continue
|
||||
if not row[0] in node_id_map:
|
||||
print('Source node {} does not exist.'.format(row[0]))
|
||||
continue
|
||||
if not row[1] in node_id_map:
|
||||
print('Destination node {} does not exist.'.format(row[1]))
|
||||
continue
|
||||
edge = {'id': line_count,
|
||||
'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0}
|
||||
yield edge
|
||||
line_count += 1
|
||||
print('Processed {} lines for edges.'.format(line_count))
|
||||
with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f:
|
||||
for k in node_id_map:
|
||||
print(k + ',' + str(node_id_map[k]), file=f)
|
|
@ -0,0 +1,2 @@
|
|||
#!/bin/bash
|
||||
python reader.py --path "/tmp/citeseer/mindrecord/citeseer_mr"
|
|
@ -0,0 +1,2 @@
|
|||
#!/bin/bash
|
||||
python reader.py --path "/tmp/cora/mindrecord/cora_mr"
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
######################## read mindrecord example ########################
Read mindrecord by MindDataset:
python reader.py --path /YourMindRecordFilePath
|
||||
"""
|
||||
import argparse
|
||||
|
||||
import mindspore.dataset as ds
|
||||
|
||||
parser = argparse.ArgumentParser(description='Mind record reader')
|
||||
parser.add_argument('--path', type=str, default="/tmp/cora/mindrecord/cora_mr",
|
||||
help='data file')
|
||||
args = parser.parse_args()
|
||||
|
||||
data_set = ds.MindDataset(args.path)
|
||||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
print(item)
|
||||
num_iter += 1
|
||||
print("Total items # is {}".format(num_iter))
|
|
@ -0,0 +1,9 @@
|
|||
#!/bin/bash
|
||||
rm /tmp/citeseer/mindrecord/*
|
||||
|
||||
python writer.py --mindrecord_script citeseer \
|
||||
--mindrecord_file "/tmp/citeseer/mindrecord/citeseer_mr" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "/tmp/citeseer/dataset/citeseer.content:/tmp/citeseer/dataset/citeseer.cites"
|
|
@ -0,0 +1,9 @@
|
|||
#!/bin/bash
|
||||
rm /tmp/cora/mindrecord/*
|
||||
|
||||
python writer.py --mindrecord_script cora \
|
||||
--mindrecord_file "/tmp/cora/mindrecord/cora_mr" \
|
||||
--mindrecord_partitions 1 \
|
||||
--mindrecord_header_size_by_bit 18 \
|
||||
--mindrecord_page_size_by_bit 20 \
|
||||
--graph_api_args "/tmp/cora/dataset/cora_content.csv:/tmp/cora/dataset/cora_cites.csv"
|
|
@ -0,0 +1,186 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
######################## write mindrecord example ########################
|
||||
Write mindrecord by data dictionary:
|
||||
python writer.py --mindrecord_script /YourScriptPath ...
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from importlib import import_module
|
||||
from multiprocessing import Pool
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
from mindspore.mindrecord import GraphMapSchema
|
||||
|
||||
|
||||
def exec_task(task_id, parallel_writer=True):
|
||||
"""
|
||||
Execute task with specified task id
|
||||
"""
|
||||
print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
|
||||
data_iter = mindrecord_dict_data(task_id)
|
||||
batch_size = 512
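# Buffer rows in memory and hand them to the writer in chunks of batch_size; each
# row is converted with transform_edge/transform_node before being written.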
|
||||
transform_count = 0
|
||||
while True:
|
||||
data_list = []
|
||||
try:
|
||||
for _ in range(batch_size):
|
||||
data = next(data_iter)
|
||||
if 'dst_id' in data:
|
||||
data = graph_map_schema.transform_edge(data)
|
||||
else:
|
||||
data = graph_map_schema.transform_node(data)
|
||||
data_list.append(data)
|
||||
transform_count += 1
|
||||
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
|
||||
print("transformed {} record...".format(transform_count))
|
||||
except StopIteration:
|
||||
if data_list:
|
||||
writer.write_raw_data(data_list, parallel_writer=parallel_writer)
|
||||
print("transformed {} record...".format(transform_count))
|
||||
break
|
||||
|
||||
|
||||
def read_args():
|
||||
"""
|
||||
read args
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='Mind record writer')
|
||||
parser.add_argument('--mindrecord_script', type=str, default="template",
|
||||
help='path where script is saved')
|
||||
|
||||
parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord/xyz",
|
||||
help='written file name prefix')
|
||||
|
||||
parser.add_argument('--mindrecord_partitions', type=int, default=1,
|
||||
help='number of written files')
|
||||
|
||||
parser.add_argument('--mindrecord_header_size_by_bit', type=int, default=24,
|
||||
help='mindrecord file header size')
|
||||
|
||||
parser.add_argument('--mindrecord_page_size_by_bit', type=int, default=25,
|
||||
help='mindrecord file page size')
|
||||
|
||||
parser.add_argument('--mindrecord_workers', type=int, default=8,
|
||||
help='number of parallel workers')
|
||||
|
||||
parser.add_argument('--num_node_tasks', type=int, default=1,
|
||||
help='number of node tasks')
|
||||
|
||||
parser.add_argument('--num_edge_tasks', type=int, default=1,
|
||||
help='number of edge tasks')
|
||||
|
||||
parser.add_argument('--graph_api_args', type=str, default="/tmp/nodes.csv:/tmp/edges.csv",
|
||||
help='nodes and edges data file, csv format with header.')
|
||||
|
||||
ret_args = parser.parse_args()
|
||||
|
||||
return ret_args
|
||||
|
||||
|
||||
def init_writer(mr_schema):
|
||||
"""
|
||||
init writer
|
||||
"""
|
||||
print("Init writer ...")
|
||||
mr_writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)
|
||||
|
||||
# set the header size
|
||||
if args.mindrecord_header_size_by_bit != 24:
|
||||
header_size = 1 << args.mindrecord_header_size_by_bit
|
||||
mr_writer.set_header_size(header_size)
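# The *_by_bit arguments are exponents: e.g. 18 gives a header size of
# 1 << 18 = 262144 bytes (256 KB); leaving the defaults (24/25) keeps the
# FileWriter's built-in header and page sizes.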
|
||||
|
||||
# set the page size
|
||||
if args.mindrecord_page_size_by_bit != 25:
|
||||
page_size = 1 << args.mindrecord_page_size_by_bit
|
||||
mr_writer.set_page_size(page_size)
|
||||
|
||||
# create the schema
|
||||
mr_writer.add_schema(mr_schema, "mindrecord_graph_schema")
|
||||
|
||||
# open file and set header
|
||||
mr_writer.open_and_set_header()
|
||||
|
||||
return mr_writer
|
||||
|
||||
|
||||
def run_parallel_workers(num_tasks):
|
||||
"""
|
||||
run parallel workers
|
||||
"""
|
||||
# set number of workers
|
||||
num_workers = args.mindrecord_workers
|
||||
|
||||
task_list = list(range(num_tasks))
|
||||
|
||||
if num_workers > num_tasks:
|
||||
num_workers = num_tasks
|
||||
|
||||
if os.name == 'nt':
|
||||
for window_task_id in task_list:
|
||||
exec_task(window_task_id, False)
|
||||
elif num_tasks > 1:
|
||||
with Pool(num_workers) as p:
|
||||
p.map(exec_task, task_list)
|
||||
else:
|
||||
exec_task(0, False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = read_args()
|
||||
print(args)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# pass mr_api arguments
|
||||
os.environ['graph_api_args'] = args.graph_api_args
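# The user's mr_api module reads this environment variable and splits it on ':'
# to recover its input file paths (see cora/mr_api.py).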
|
||||
|
||||
# import mr_api
|
||||
try:
|
||||
mr_api = import_module(args.mindrecord_script + '.mr_api')
|
||||
except ModuleNotFoundError:
|
||||
raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))
|
||||
|
||||
# init graph schema
|
||||
graph_map_schema = GraphMapSchema()
|
||||
|
||||
num_features, feature_data_types, feature_shapes = mr_api.node_profile
|
||||
graph_map_schema.set_node_feature_profile(num_features, feature_data_types, feature_shapes)
|
||||
|
||||
num_features, feature_data_types, feature_shapes = mr_api.edge_profile
|
||||
graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
|
||||
|
||||
graph_schema = graph_map_schema.get_schema()
|
||||
|
||||
# init writer
|
||||
writer = init_writer(graph_schema)
|
||||
|
||||
# write nodes data
|
||||
mindrecord_dict_data = mr_api.yield_nodes
|
||||
run_parallel_workers(args.num_node_tasks)
|
||||
|
||||
# write edges data
|
||||
mindrecord_dict_data = mr_api.yield_edges
|
||||
run_parallel_workers(args.num_edge_tasks)
|
||||
|
||||
# writer wrap up
|
||||
ret = writer.commit()
|
||||
|
||||
end_time = time.time()
|
||||
print("--------------------------------------------")
|
||||
print("END. Total time: {}".format(end_time - start_time))
|
||||
print("--------------------------------------------")
|
|
@ -66,6 +66,7 @@ set(submodules
|
|||
$<TARGET_OBJECTS:APItoPython>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source>
|
||||
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
|
||||
$<TARGET_OBJECTS:engine-gnn>
|
||||
$<TARGET_OBJECTS:engine-datasetops>
|
||||
$<TARGET_OBJECTS:engine-opt>
|
||||
$<TARGET_OBJECTS:engine>
|
||||
|
|
|
@ -61,6 +61,7 @@
|
|||
#include "dataset/engine/jagged_connector.h"
|
||||
#include "dataset/engine/datasetops/source/text_file_op.h"
|
||||
#include "dataset/engine/datasetops/source/voc_op.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
#include "dataset/kernels/data/to_float16_op.h"
|
||||
#include "dataset/util/random.h"
|
||||
#include "mindrecord/include/shard_operator.h"
|
||||
|
@ -513,6 +514,33 @@ void bindVocabObjects(py::module *m) {
|
|||
});
|
||||
}
|
||||
|
||||
void bindGraphData(py::module *m) {
|
||||
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
|
||||
.def(py::init([](std::string dataset_file, int32_t num_workers) {
|
||||
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
|
||||
THROW_IF_ERROR(g_out->Init());
|
||||
return g_out;
|
||||
}))
|
||||
.def("get_nodes",
|
||||
[](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_all_neighbors",
|
||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_node_feature",
|
||||
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
||||
TensorRow out;
|
||||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||
return out;
|
||||
});
|
||||
}
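// The class above is exposed to Python through the _c_dataengine module below as
// Graph(dataset_file, num_workers) with get_nodes, get_all_neighbors and
// get_node_feature methods.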
|
||||
|
||||
// This is where we externalize the C logic as python modules
|
||||
PYBIND11_MODULE(_c_dataengine, m) {
|
||||
m.doc() = "pybind11 for _c_dataengine";
|
||||
|
@ -578,6 +606,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
|
|||
bindDatasetOps(&m);
|
||||
bindInfoObjects(&m);
|
||||
bindVocabObjects(&m);
|
||||
bindGraphData(&m);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
add_subdirectory(datasetops)
|
||||
add_subdirectory(opt)
|
||||
add_subdirectory(gnn)
|
||||
if (ENABLE_TDTQUE)
|
||||
add_subdirectory(tdt)
|
||||
endif ()
|
||||
|
@ -15,7 +16,7 @@ add_library(engine OBJECT
|
|||
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
|
||||
|
||||
if (ENABLE_TDTQUE)
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt)
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn)
|
||||
else()
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt)
|
||||
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn)
|
||||
endif ()
|
||||
|
|
|
@ -112,10 +112,10 @@ Status MindRecordOp::Init() {
|
|||
|
||||
data_schema_ = std::make_unique<DataSchema>();
|
||||
|
||||
std::vector<std::string> col_names = shard_reader_->get_shard_column()->GetColumnName();
|
||||
std::vector<std::string> col_names = shard_reader_->GetShardColumn()->GetColumnName();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found");
|
||||
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->get_shard_column()->GeColumnDataType();
|
||||
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->get_shard_column()->GetColumnShape();
|
||||
std::vector<mindrecord::ColumnDataType> col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType();
|
||||
std::vector<std::vector<int64_t>> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape();
|
||||
|
||||
bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything
|
||||
std::map<std::string, int32_t> colname_to_ind;
|
||||
|
@ -296,8 +296,7 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
|
|||
std::vector<int64_t> column_shape;
|
||||
|
||||
// Get column data
|
||||
|
||||
auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName(
|
||||
auto has_column = shard_reader_->GetShardColumn()->GetColumnValueByName(
|
||||
column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size,
|
||||
&column_shape);
|
||||
if (has_column == MSRStatus::FAILED) {
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
add_library(engine-gnn OBJECT
|
||||
graph.cc
|
||||
graph_loader.cc
|
||||
local_node.cc
|
||||
local_edge.cc
|
||||
feature.cc
|
||||
)
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_EDGE_H_
|
||||
#define DATASET_ENGINE_GNN_EDGE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
#include "dataset/engine/gnn/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
using EdgeType = int8_t;
|
||||
using EdgeIdType = int32_t;
|
||||
|
||||
class Edge {
|
||||
public:
|
||||
// Constructor
|
||||
// @param EdgeIdType id - edge id
|
||||
// @param EdgeType type - edge type
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
Edge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
|
||||
: id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {}
|
||||
|
||||
virtual ~Edge() = default;
|
||||
|
||||
// @return EdgeIdType - Returned edge id
|
||||
EdgeIdType id() const { return id_; }
|
||||
|
||||
// @return EdgeType - Returned edge type
|
||||
EdgeType type() const { return type_; }
|
||||
|
||||
// Get the feature of an edge
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
|
||||
|
||||
// Get nodes on the edge
|
||||
// @param std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node - Source and destination nodes returned
|
||||
Status GetNode(std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node) {
|
||||
*out_node = std::make_pair(src_node_, dst_node_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Set node to edge
|
||||
// @param const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node -
|
||||
Status SetNode(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node) {
|
||||
src_node_ = in_node.first;
|
||||
dst_node_ = in_node.second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Update feature of edge
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
|
||||
|
||||
protected:
|
||||
EdgeIdType id_;
|
||||
EdgeType type_;
|
||||
std::shared_ptr<Node> src_node_;
|
||||
std::shared_ptr<Node> dst_node_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_EDGE_H_
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_FEATURE_H_
|
||||
#define DATASET_ENGINE_GNN_FEATURE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
using FeatureType = int16_t;
|
||||
|
||||
class Feature {
|
||||
public:
|
||||
// Constructor
|
||||
// @param FeatureType type_name - feature type
|
||||
// @param std::shared_ptr<Tensor> value - feature value
|
||||
Feature(FeatureType type_name, std::shared_ptr<Tensor> value);
|
||||
|
||||
// Get feature value
|
||||
// @return std::shared_ptr<Tensor> *out_value - feature value
|
||||
const std::shared_ptr<Tensor> Value() const { return value_; }
|
||||
|
||||
// @return FeatureType - Returned feature type
|
||||
FeatureType type() const { return type_name_; }
|
||||
|
||||
private:
|
||||
FeatureType type_name_;
|
||||
std::shared_ptr<Tensor> value_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_FEATURE_H_
|
|
@ -0,0 +1,251 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/core/tensor_shape.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) {
|
||||
MS_LOG(INFO) << "num_workers:" << num_workers;
|
||||
}
|
||||
|
||||
Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out) {
|
||||
auto itr = node_type_map_.find(node_type);
|
||||
if (itr == node_type_map_.end()) {
|
||||
std::string err_msg = "Invalid node type:" + std::to_string(node_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
if (node_num == -1) {
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>({itr->second}, DataType(DataType::DE_INT32), out));
|
||||
} else {
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
if (!type.IsCompatible<T>()) {
|
||||
RETURN_STATUS_UNEXPECTED("Data type not compatible");
|
||||
}
|
||||
if (data.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Input data is emply");
|
||||
}
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
size_t m = data.size();
|
||||
size_t n = data[0].size();
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(
|
||||
&tensor, TensorImpl::kFlexible, TensorShape({static_cast<dsize_t>(m), static_cast<dsize_t>(n)}), type, nullptr));
|
||||
T *ptr = reinterpret_cast<T *>(tensor->GetMutableBuffer());
|
||||
for (auto id_m : data) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size");
|
||||
for (auto id_n : id_m) {
|
||||
*ptr = id_n;
|
||||
ptr++;
|
||||
}
|
||||
}
|
||||
tensor->Squeeze();
|
||||
*out = std::move(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) {
|
||||
if (!data || data->empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("Input data is emply");
|
||||
}
|
||||
for (std::vector<T> &vec : *data) {
|
||||
size_t size = vec.size();
|
||||
if (size > max_size) {
|
||||
RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal");
|
||||
} else {
|
||||
for (size_t i = 0; i < (max_size - size); ++i) {
|
||||
vec.push_back(default_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out) { return Status::OK(); }
|
||||
|
||||
Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
|
||||
std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
|
||||
std::vector<std::vector<NodeIdType>> neighbors;
|
||||
size_t max_neighbor_num = 0;
|
||||
neighbors.resize(node_list.size());
|
||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||
auto itr = node_id_map_.find(node_list[i]);
|
||||
if (itr != node_id_map_.end()) {
|
||||
RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i]));
|
||||
max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
|
||||
} else {
|
||||
std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId));
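// Shorter neighbor lists are padded with kDefaultNodeId so every row has
// max_neighbor_num entries and the result can be packed into one rectangular tensor.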
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors, DataType(DataType::DE_INT32), out));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p,
|
||||
float q, NodeIdType default_node, std::shared_ptr<Tensor> *out) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = default_feature_map_.find(feature_type);
|
||||
if (itr == default_feature_map_.end()) {
|
||||
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
*out_feature = itr->second;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) {
|
||||
if (!nodes || nodes->Size() == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("Inpude nodes is empty");
|
||||
}
|
||||
TensorRow tensors;
|
||||
for (auto f_type : feature_types) {
|
||||
std::shared_ptr<Feature> default_feature;
|
||||
// If no feature can be obtained, fill in the default value
|
||||
RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature));
|
||||
|
||||
TensorShape shape(default_feature->Value()->shape());
|
||||
auto shape_vec = nodes->shape().AsVector();
|
||||
dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
|
||||
shape = shape.PrependDim(size);
|
||||
std::shared_ptr<Tensor> fea_tensor;
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr));
|
||||
|
||||
dsize_t index = 0;
|
||||
for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
|
||||
auto itr = node_id_map_.find(*node_itr);
|
||||
std::shared_ptr<Feature> feature;
|
||||
if (itr != node_id_map_.end()) {
|
||||
if (!itr->second->GetFeatures(f_type, &feature).IsOk()) {
|
||||
feature = default_feature;
|
||||
}
|
||||
} else {
|
||||
if (*node_itr == kDefaultNodeId) {
|
||||
feature = default_feature;
|
||||
} else {
|
||||
std::string err_msg = "Invalid node id:" + std::to_string(*node_itr);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));
|
||||
index++;
|
||||
}
|
||||
|
||||
TensorShape reshape(nodes->shape());
|
||||
for (auto s : default_feature->Value()->shape().AsVector()) {
|
||||
reshape = reshape.AppendDim(s);
|
||||
}
|
||||
RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
|
||||
fea_tensor->Squeeze();
|
||||
tensors.push_back(fea_tensor);
|
||||
}
|
||||
*out = std::move(tensors);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::Init() {
|
||||
RETURN_IF_NOT_OK(LoadNodeAndEdge());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info) {
|
||||
node_info->reserve(node_type_map_.size());
|
||||
for (auto node : node_type_map_) {
|
||||
NodeMetaInfo n_info;
|
||||
n_info.type = node.first;
|
||||
n_info.num = node.second.size();
|
||||
auto itr = node_feature_map_.find(node.first);
|
||||
if (itr != node_feature_map_.end()) {
|
||||
for (auto f_type : itr->second) {
|
||||
n_info.feature_type.push_back(f_type);
|
||||
}
|
||||
std::sort(n_info.feature_type.begin(), n_info.feature_type.end());
|
||||
}
|
||||
node_info->push_back(n_info);
|
||||
}
|
||||
|
||||
edge_info->reserve(edge_type_map_.size());
|
||||
for (auto edge : edge_type_map_) {
|
||||
EdgeMetaInfo e_info;
|
||||
e_info.type = edge.first;
|
||||
e_info.num = edge.second.size();
|
||||
auto itr = edge_feature_map_.find(edge.first);
|
||||
if (itr != edge_feature_map_.end()) {
|
||||
for (auto f_type : itr->second) {
|
||||
e_info.feature_type.push_back(f_type);
|
||||
}
|
||||
}
|
||||
edge_info->push_back(e_info);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Graph::LoadNodeAndEdge() {
|
||||
GraphLoader gl(dataset_file_, num_workers_);
|
||||
// ask graph_loader to load everything into memory
|
||||
RETURN_IF_NOT_OK(gl.InitAndLoad());
|
||||
// get all maps
|
||||
RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_,
|
||||
&node_feature_map_, &edge_feature_map_, &default_feature_map_));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,166 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_GRAPH_H_
|
||||
#define DATASET_ENGINE_GNN_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/gnn/graph_loader.h"
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
#include "dataset/engine/gnn/node.h"
|
||||
#include "dataset/engine/gnn/edge.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
struct NodeMetaInfo {
|
||||
NodeType type;
|
||||
NodeIdType num;
|
||||
std::vector<FeatureType> feature_type;
|
||||
NodeMetaInfo() {
|
||||
type = 0;
|
||||
num = 0;
|
||||
}
|
||||
};
|
||||
|
||||
struct EdgeMetaInfo {
|
||||
EdgeType type;
|
||||
EdgeIdType num;
|
||||
std::vector<FeatureType> feature_type;
|
||||
EdgeMetaInfo() {
|
||||
type = 0;
|
||||
num = 0;
|
||||
}
|
||||
};
|
||||
|
||||
class Graph {
|
||||
public:
|
||||
// Constructor
|
||||
// @param std::string dataset_file - path of the MindRecord graph data file
|
||||
// @param int32_t num_workers - number of parallel threads
|
||||
Graph(std::string dataset_file, int32_t num_workers);
|
||||
|
||||
~Graph() = default;
|
||||
|
||||
// Get the nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
// @return Status - The error code return
|
||||
Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get the edges from the graph.
|
||||
// @param EdgeType edge_type - type of edge
|
||||
// @param EdgeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired
|
||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||
// @return Status - The error code return
|
||||
Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get all neighbors of the nodes in node_list.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status - The error code return
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out);
|
||||
|
||||
Status GetSampledNeighbor(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
|
||||
Status GetNegSampledNeighbor(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
|
||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, float p, float q,
|
||||
NodeIdType default_node, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status - The error code return
|
||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out);
|
||||
|
||||
// Get the feature of an edge
|
||||
// @param std::shared_ptr<Tensor> edget - List of edges
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status - The error code return
|
||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out);
|
||||
|
||||
// Get meta information of graph
|
||||
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of node
|
||||
// @param std::vector<NodeMetaInfo> *node_info - Returned meta information of edge
|
||||
// @return Status - The error code return
|
||||
Status GetMetaInfo(std::vector<NodeMetaInfo> *node_info, std::vector<EdgeMetaInfo> *edge_info);
|
||||
|
||||
Status Init();
|
||||
|
||||
private:
|
||||
// Load graph data from mindrecord file
|
||||
// @return Status - The error code return
|
||||
Status LoadNodeAndEdge();
|
||||
|
||||
// Create Tensor By Vector
|
||||
// @param std::vector<std::vector<T>> &data -
|
||||
// @param DataType type -
|
||||
// @param std::shared_ptr<Tensor> *out -
|
||||
// @return Status - The error code return
|
||||
template <typename T>
|
||||
Status CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Pad each inner vector of data up to max_size
// @param std::vector<std::vector<T>> *data - Vectors to be padded
// @param size_t max_size - Target size after padding
// @param T default_value - Default value used for padding
|
||||
// @return Status - The error code return
|
||||
template <typename T>
|
||||
Status ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value);
|
||||
|
||||
// Get the default feature of a node
|
||||
// @param FeatureType feature_type -
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature);
|
||||
|
||||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
|
||||
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
|
||||
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
|
||||
|
||||
std::unordered_map<EdgeType, std::vector<EdgeIdType>> edge_type_map_;
|
||||
std::unordered_map<EdgeIdType, std::shared_ptr<Edge>> edge_id_map_;
|
||||
|
||||
std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_;
|
||||
std::unordered_map<NodeType, std::unordered_set<FeatureType>> edge_feature_map_;
|
||||
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_feature_map_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_GRAPH_H_
|
|
@ -0,0 +1,248 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <future>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/engine/gnn/graph_loader.h"
|
||||
#include "mindspore/ccsrc/mindrecord/include/shard_error.h"
|
||||
#include "dataset/engine/gnn/local_edge.h"
|
||||
#include "dataset/engine/gnn/local_node.h"
|
||||
|
||||
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
using mindrecord::MSRStatus;
|
||||
|
||||
GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
|
||||
: mr_path_(mr_filepath),
|
||||
num_workers_(num_workers),
|
||||
row_id_(0),
|
||||
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
|
||||
|
||||
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
|
||||
EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
|
||||
EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) {
|
||||
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
|
||||
while (dq.empty() == false) {
|
||||
std::shared_ptr<Node> node_ptr = dq.front();
|
||||
n_id_map->insert({node_ptr->id(), node_ptr});
|
||||
(*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
|
||||
dq.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
for (std::deque<std::shared_ptr<Edge>> &dq : e_deques_) {
|
||||
while (dq.empty() == false) {
|
||||
std::shared_ptr<Edge> edge_ptr = dq.front();
|
||||
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
|
||||
RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
|
||||
auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(p.first->id()));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid dst_id:" + std::to_string(p.second->id()));
|
||||
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
|
||||
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
|
||||
(*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
|
||||
dq.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
|
||||
for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
|
||||
|
||||
MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoader::InitAndLoad() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n");
|
||||
n_deques_.resize(num_workers_);
|
||||
e_deques_.resize(num_workers_);
|
||||
n_feature_maps_.resize(num_workers_);
|
||||
e_feature_maps_.resize(num_workers_);
|
||||
default_feature_maps_.resize(num_workers_);
|
||||
std::vector<std::future<Status>> r_codes(num_workers_);
|
||||
|
||||
shard_reader_ = std::make_unique<ShardReader>();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS,
|
||||
"Fail to open" + mr_path_);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
|
||||
|
||||
mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
|
||||
for (const std::string &key : keys_) {
|
||||
if (schema.find(key) == schema.end()) {
|
||||
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
|
||||
}
|
||||
}
|
||||
|
||||
// launching worker threads
|
||||
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
||||
r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id);
|
||||
}
|
||||
// wait for threads to finish and check its return code
|
||||
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
||||
RETURN_IF_NOT_OK(r_codes[wkr_id].get());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
|
||||
std::shared_ptr<Node> *node, NodeFeatureMap *feature_map,
|
||||
DefaultFeatureMap *default_feature) {
|
||||
NodeIdType node_id = col_jsn["first_id"];
|
||||
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
|
||||
(*node) = std::make_shared<LocalNode>(node_id, node_type);
|
||||
std::vector<int32_t> indices;
|
||||
RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
|
||||
|
||||
for (int32_t ind : indices) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
|
||||
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
|
||||
(*feature_map)[node_type].insert(ind);
|
||||
if ((*default_feature)[ind] == nullptr) {
|
||||
std::shared_ptr<Tensor> zero_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
|
||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrecord::json &col_jsn,
|
||||
std::shared_ptr<Edge> *edge, EdgeFeatureMap *feature_map,
|
||||
DefaultFeatureMap *default_feature) {
|
||||
EdgeIdType edge_id = col_jsn["first_id"];
|
||||
EdgeType edge_type = static_cast<EdgeType>(col_jsn["type"]);
|
||||
NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"];
|
||||
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1);
|
||||
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
|
||||
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
|
||||
std::vector<int32_t> indices;
|
||||
RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
|
||||
for (int32_t ind : indices) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
|
||||
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
|
||||
(*feature_map)[edge_type].insert(ind);
|
||||
if ((*default_feature)[ind] == nullptr) {
|
||||
std::shared_ptr<Tensor> zero_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type()));
|
||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||
const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
|
||||
const unsigned char *data = nullptr;
|
||||
std::unique_ptr<unsigned char[]> data_ptr;
|
||||
uint64_t n_bytes = 0, col_type_size = 1;
|
||||
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
||||
std::vector<int64_t> column_shape;
|
||||
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
|
||||
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
|
||||
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible,
|
||||
std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
|
||||
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||
const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
|
||||
const unsigned char *data = nullptr;
|
||||
std::unique_ptr<unsigned char[]> data_ptr;
|
||||
uint64_t n_bytes = 0, col_type_size = 1;
|
||||
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
||||
std::vector<int64_t> column_shape;
|
||||
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
|
||||
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
|
||||
|
||||
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
||||
|
||||
for (uint64_t i = 0; i < n_bytes; i += col_type_size) {
|
||||
int32_t feature_ind = -1;
|
||||
if (col_type == mindrecord::ColumnInt32) {
|
||||
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
|
||||
} else if (col_type == mindrecord::ColumnInt64) {
|
||||
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
|
||||
}
|
||||
if (feature_ind >= 0) indices->push_back(feature_ind);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoader::WorkerEntry(int32_t worker_id) {
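// row_id_ is an atomic counter shared by all workers: each GetNextById call claims the next row id,
// so rows are spread across workers without extra coordination; an empty result ends this worker's loop.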
|
||||
ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id);
|
||||
while (rows.empty() == false) {
|
||||
for (const auto &tupled_row : rows) {
|
||||
std::vector<uint8_t> col_blob = std::get<0>(tupled_row);
|
||||
mindrecord::json col_jsn = std::get<1>(tupled_row);
|
||||
std::string attr = col_jsn["attribute"];
|
||||
if (attr == "n") {
|
||||
std::shared_ptr<Node> node_ptr;
|
||||
RETURN_IF_NOT_OK(
|
||||
LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
|
||||
n_deques_[worker_id].emplace_back(node_ptr);
|
||||
} else if (attr == "e") {
|
||||
std::shared_ptr<Edge> edge_ptr;
|
||||
RETURN_IF_NOT_OK(
|
||||
LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id]));
|
||||
e_deques_[worker_id].emplace_back(edge_ptr);
|
||||
} else {
|
||||
MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node.";
|
||||
}
|
||||
}
|
||||
rows = shard_reader_->GetNextById(row_id_++, worker_id);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
|
||||
DefaultFeatureMap *default_feature_map) {
|
||||
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
|
||||
for (auto &m : n_feature_maps_[wkr_id]) {
|
||||
for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
|
||||
}
|
||||
for (auto &m : e_feature_maps_[wkr_id]) {
|
||||
for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
|
||||
}
|
||||
for (auto &m : default_feature_maps_[wkr_id]) {
|
||||
(*default_feature_map)[m.first] = m.second;
|
||||
}
|
||||
}
|
||||
n_feature_maps_.clear();
|
||||
e_feature_maps_.clear();
|
||||
}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_
|
||||
#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
#include "dataset/engine/gnn/graph.h"
|
||||
#include "dataset/engine/gnn/node.h"
|
||||
#include "dataset/engine/gnn/edge.h"
|
||||
#include "dataset/util/status.h"
|
||||
#include "mindrecord/include/shard_reader.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
using mindrecord::ShardReader;
|
||||
using NodeIdMap = std::unordered_map<NodeIdType, std::shared_ptr<Node>>;
|
||||
using EdgeIdMap = std::unordered_map<EdgeIdType, std::shared_ptr<Edge>>;
|
||||
using NodeTypeMap = std::unordered_map<NodeType, std::vector<NodeIdType>>;
|
||||
using EdgeTypeMap = std::unordered_map<EdgeType, std::vector<EdgeIdType>>;
|
||||
using NodeFeatureMap = std::unordered_map<NodeType, std::unordered_set<FeatureType>>;
|
||||
using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureType>>;
|
||||
using DefaultFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
|
||||
|
||||
// this class interfaces with the underlying storage format (mindrecord)
|
||||
// it returns raw nodes and edges via GetNodesAndEdges
|
||||
// it is then the responsibility of graph to construct itself based on the nodes and edges
|
||||
// if needed, this class could become a base where each derived class handles a specific storage format
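// A rough usage sketch (illustrative only; names mirror the declarations below, the file path and
// worker count are placeholders):
//   GraphLoader loader("/path/to/graph.mindrecord", 4);
//   RETURN_IF_NOT_OK(loader.InitAndLoad());
//   NodeIdMap n_id_map; EdgeIdMap e_id_map; NodeTypeMap n_type_map; EdgeTypeMap e_type_map;
//   NodeFeatureMap n_feat_map; EdgeFeatureMap e_feat_map; DefaultFeatureMap default_feat_map;
//   RETURN_IF_NOT_OK(loader.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map,
//                                            &n_feat_map, &e_feat_map, &default_feat_map));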
|
||||
class GraphLoader {
|
||||
public:
|
||||
explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4);
|
||||
|
||||
// Init mindrecord and load everything into memory multi-threaded
|
||||
// @return Status - the status code
|
||||
Status InitAndLoad();
|
||||
|
||||
// this function will query mindrecord and construct all nodes and edges
|
||||
// nodes and edges are added to the maps without any connection. That's because these nodes and edges are read in
|
||||
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
|
||||
// features attached to each node and edge are expected to be filled correctly
|
||||
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *,
|
||||
DefaultFeatureMap *);
|
||||
|
||||
private:
|
||||
//
|
||||
// worker thread that reads mindrecord file
|
||||
// @param int32_t worker_id - id of each worker
|
||||
// @return Status - the status code
|
||||
Status WorkerEntry(int32_t worker_id);
|
||||
|
||||
// Load a node based on 1 row of mindrecord, returns a shared_ptr<Node>
|
||||
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
||||
// @param mindrecord::json &jsn - contains raw data
|
||||
// @param std::shared_ptr<Node> *node - return value
|
||||
// @param NodeFeatureMap *feature_map - output map recording which feature types appear for each node type
|
||||
// @param DefaultFeatureMap *default_feature - output map of zero-filled default features, one per feature type
|
||||
// @return Status - the status code
|
||||
Status LoadNode(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Node> *node,
|
||||
NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature);
|
||||
|
||||
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
||||
// @param mindrecord::json &jsn - contains raw data
|
||||
// @param std::shared_ptr<Edge> *edge - return value, the edge ptr, edge is not yet connected
|
||||
// @param EdgeFeatureMap *feature_map - output map recording which feature types appear for each edge type
|
||||
// @param DefaultFeatureMap *default_feature - output map of zero-filled default features, one per feature type
|
||||
// @return Status - the status code
|
||||
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
|
||||
EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature);
|
||||
|
||||
// @param std::string key - column name
|
||||
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
||||
// @param mindrecord::json &jsn - contains raw data
|
||||
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
|
||||
// @return Status - the status code
|
||||
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
|
||||
std::vector<int32_t> *ind);
|
||||
|
||||
// @param std::string &key - column name
|
||||
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
||||
// @param mindrecord::json &jsn - contains raw data
|
||||
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
|
||||
// @return Status - the status code
|
||||
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
|
||||
std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
|
||||
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *);
|
||||
|
||||
const int32_t num_workers_;
|
||||
std::atomic_int row_id_;
|
||||
std::string mr_path_;
|
||||
std::unique_ptr<ShardReader> shard_reader_;
|
||||
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
|
||||
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
|
||||
std::vector<NodeFeatureMap> n_feature_maps_;
|
||||
std::vector<EdgeFeatureMap> e_feature_maps_;
|
||||
std::vector<DefaultFeatureMap> default_feature_maps_;
|
||||
const std::vector<std::string> keys_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/gnn/local_edge.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
|
||||
: Edge(id, type, src_node, dst_node) {}
|
||||
|
||||
Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
if (itr != features_.end()) {
|
||||
*out_feature = itr->second;
|
||||
return Status::OK();
|
||||
} else {
|
||||
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
|
||||
auto itr = features_.find(feature->type());
|
||||
if (itr != features_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Feature already exists");
|
||||
} else {
|
||||
features_[feature->type()] = feature;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_
|
||||
#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/engine/gnn/edge.h"
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
#include "dataset/engine/gnn/node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
class LocalEdge : public Edge {
|
||||
public:
|
||||
// Constructor
|
||||
// @param EdgeIdType id - edge id
|
||||
// @param EdgeType type - edge type
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node);
|
||||
|
||||
~LocalEdge() = default;
|
||||
|
||||
// Get the feature of an edge
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
|
||||
|
||||
// Update feature of edge
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_
|
|
@ -0,0 +1,83 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/engine/gnn/local_node.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "dataset/engine/gnn/edge.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {}
|
||||
|
||||
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
if (itr != features_.end()) {
|
||||
*out_feature = itr->second;
|
||||
return Status::OK();
|
||||
} else {
|
||||
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
if (samples_num == -1) {
|
||||
// Return all neighbors
|
||||
neighbors.resize(itr->second.size() + 1);
|
||||
neighbors[0] = id_;
|
||||
std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1,
|
||||
[](const std::shared_ptr<Node> &node) { return node->id(); });
|
||||
} else {
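// Note: fixed-size neighbor sampling (samples_num != -1) is not implemented in this version;
// only the "return all neighbors" branch above fills the result.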
|
||||
}
|
||||
} else {
|
||||
neighbors.push_back(id_);
|
||||
MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type;
|
||||
}
|
||||
*out_neighbors = std::move(neighbors);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::AddNeighbor(const std::shared_ptr<Node> &node) {
|
||||
auto itr = neighbor_nodes_.find(node->type());
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
itr->second.push_back(node);
|
||||
} else {
|
||||
neighbor_nodes_[node->type()] = {node};
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) {
|
||||
auto itr = features_.find(feature->type());
|
||||
if (itr != features_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Feature already exists");
|
||||
} else {
|
||||
features_[feature->type()] = feature;
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_
|
||||
#define DATASET_ENGINE_GNN_LOCAL_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/engine/gnn/node.h"
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
class LocalNode : public Node {
|
||||
public:
|
||||
// Constructor
|
||||
// @param NodeIdType id - node id
|
||||
// @param NodeType type - node type
|
||||
LocalNode(NodeIdType id, NodeType type);
|
||||
|
||||
~LocalNode() = default;
|
||||
|
||||
// Get the feature of a node
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) override;
|
||||
|
||||
// Get the neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired; -1 means all neighbors are acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) override;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @return Status - The error code return
|
||||
Status AddNeighbor(const std::shared_ptr<Node> &node) override;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
std::unordered_map<NodeType, std::vector<std::shared_ptr<Node>>> neighbor_nodes_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_
|
|
@ -0,0 +1,79 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef DATASET_ENGINE_GNN_NODE_H_
|
||||
#define DATASET_ENGINE_GNN_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/engine/gnn/feature.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
using NodeType = int8_t;
|
||||
using NodeIdType = int32_t;
|
||||
|
||||
constexpr NodeIdType kDefaultNodeId = -1;
|
||||
|
||||
class Node {
|
||||
public:
|
||||
// Constructor
|
||||
// @param NodeIdType id - node id
|
||||
// @param NodeType type - node type
|
||||
Node(NodeIdType id, NodeType type) : id_(id), type_(type) {}
|
||||
|
||||
virtual ~Node() = default;
|
||||
|
||||
// @return NodeIdType - Returned node id
|
||||
NodeIdType id() const { return id_; }
|
||||
|
||||
// @return NodeType - Returned node type
|
||||
NodeType type() const { return type_; }
|
||||
|
||||
// Get the feature of a node
|
||||
// @param FeatureType feature_type - type of feature
|
||||
// @param std::shared_ptr<Feature> *out_feature - Returned feature
|
||||
// @return Status - The error code return
|
||||
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
|
||||
|
||||
// Get the neighbors of a node
|
||||
// @param NodeType neighbor_type - type of neighbor
|
||||
// @param int32_t samples_num - Number of neighbors to be acquired; -1 means all neighbors are acquired
|
||||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status - The error code return
|
||||
virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @return Status - The error code return
|
||||
virtual Status AddNeighbor(const std::shared_ptr<Node> &node) = 0;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @return Status - The error code return
|
||||
virtual Status UpdateFeature(const std::shared_ptr<Feature> &feature) = 0;
|
||||
|
||||
protected:
|
||||
NodeIdType id_;
|
||||
NodeType type_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // DATASET_ENGINE_GNN_NODE_H_
|
|
@ -114,7 +114,7 @@ class ShardReader {
|
|||
|
||||
/// \brief aim to get columns context
|
||||
/// \return the columns
|
||||
std::shared_ptr<ShardColumn> get_shard_column() const;
|
||||
std::shared_ptr<ShardColumn> GetShardColumn() const;
|
||||
|
||||
/// \brief get the number of shards
|
||||
/// \return # of shards
|
||||
|
|
|
@ -232,7 +232,7 @@ void ShardReader::Close() {
|
|||
|
||||
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
|
||||
|
||||
std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; }
|
||||
std::shared_ptr<ShardColumn> ShardReader::GetShardColumn() const { return shard_column_; }
|
||||
|
||||
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
|
||||
|
||||
|
|
|
@ -25,9 +25,10 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset
|
|||
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
|
||||
WeightedRandomSampler, Sampler
|
||||
from .engine.serializer_deserializer import serialize, deserialize, show
|
||||
from .engine.graphdata import GraphData
|
||||
|
||||
__all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset",
|
||||
"MindDataset", "GeneratorDataset", "TFRecordDataset",
|
||||
"ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset",
|
||||
"VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler",
|
||||
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"]
|
||||
"SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"]
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
graphdata.py supports loading graph dataset for GNN network training,
|
||||
and provides operations related to graph data.
|
||||
"""
|
||||
import numpy as np
|
||||
from mindspore._c_dataengine import Graph
|
||||
from mindspore._c_dataengine import Tensor
|
||||
|
||||
from .validators import check_gnn_get_all_nodes, check_gnn_get_all_neighbors, check_gnn_get_node_feature
|
||||
|
||||
|
||||
class GraphData:
|
||||
"""
|
||||
Reads the graph dataset used for GNN training from the shared file and database.
|
||||
|
||||
Args:
|
||||
dataset_file (str): One of file names in dataset.
|
||||
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
|
||||
(default=None).
|
||||
"""
|
||||
|
||||
def __init__(self, dataset_file, num_parallel_workers=None):
|
||||
self._dataset_file = dataset_file
|
||||
if num_parallel_workers is None:
|
||||
num_parallel_workers = 1
|
||||
self._graph = Graph(dataset_file, num_parallel_workers)
|
||||
|
||||
@check_gnn_get_all_nodes
|
||||
def get_all_nodes(self, node_type):
|
||||
"""
|
||||
Get all nodes in the graph.
|
||||
|
||||
Args:
|
||||
node_type (int): Specify the type of node.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.get_all_nodes(0)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_type` is not integer.
|
||||
"""
|
||||
return self._graph.get_nodes(node_type, -1).as_array()
|
||||
|
||||
@check_gnn_get_all_neighbors
|
||||
def get_all_neighbors(self, node_list, neighbor_type):
|
||||
"""
|
||||
Get `neighbor_type` neighbors of the nodes in `node_list`.
|
||||
|
||||
Args:
|
||||
node_list (list or numpy.ndarray): The given list of nodes.
|
||||
neighbor_type (int): Specify the type of neighbor.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of neighbor nodes.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.get_all_nodes(0)
|
||||
>>> neighbors = data_graph.get_all_neighbors(nodes[0], 0)
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
TypeError: If `neighbor_type` is not integer.
|
||||
"""
|
||||
return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()
|
||||
|
||||
@check_gnn_get_node_feature
|
||||
def get_node_feature(self, node_list, feature_types):
|
||||
"""
|
||||
Get `feature_types` feature of the nodes in `node_list`.
|
||||
|
||||
Args:
|
||||
node_list (list or numpy.ndarray): The given list of nodes.
|
||||
feature_types (list or ndarray): The given list of feature types.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: array of features.
|
||||
|
||||
Examples:
|
||||
>>> import mindspore.dataset as ds
|
||||
>>> data_graph = ds.GraphData('dataset_file', 2)
|
||||
>>> nodes = data_graph.get_all_nodes(0)
|
||||
>>> features = data_graph.get_node_feature(nodes[0], [1])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
TypeError: If `feature_types` is not list or ndarray.
|
||||
"""
|
||||
if isinstance(node_list, list):
|
||||
node_list = np.array(node_list, dtype=np.int32)
|
||||
return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)]
|
|
@ -1032,6 +1032,7 @@ def check_textfiledataset(method):
|
|||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_split(method):
|
||||
"""check the input arguments of split."""
|
||||
|
||||
|
@ -1072,3 +1073,85 @@ def check_split(method):
|
|||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_list_or_ndarray(param, param_name):
|
||||
if (not isinstance(param, list)) and (not hasattr(param, 'tolist')):
|
||||
raise TypeError("Wrong input type for {0}, should be list, got {1}".format(
|
||||
param_name, type(param)))
|
||||
|
||||
|
||||
def check_gnn_get_all_nodes(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function."""
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_type; required argument
|
||||
check_type(param_dict.get("node_type"), 'node_type', int)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_all_neighbors(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
check_list_or_ndarray(param_dict.get("node_list"), 'node_list')
|
||||
|
||||
# check neighbor_type; required argument
|
||||
check_type(param_dict.get("neighbor_type"), 'neighbor_type', int)
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_aligned_list(param, param_name):
|
||||
"""Check whether the structure of each member of the list is the same."""
|
||||
if not isinstance(param, list):
|
||||
raise TypeError("Parameter {0} is not a list".format(param_name))
|
||||
member_have_list = None
|
||||
list_len = None
|
||||
for member in param:
|
||||
if isinstance(member, list):
|
||||
check_aligned_list(member, param_name)
|
||||
if member_have_list not in (None, True):
|
||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
if list_len is not None and len(member) != list_len:
|
||||
raise TypeError("The size of each member of parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
member_have_list = True
|
||||
list_len = len(member)
|
||||
else:
|
||||
if member_have_list not in (None, False):
|
||||
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
||||
param_name))
|
||||
member_have_list = False
|
||||
|
||||
|
||||
def check_gnn_get_node_feature(method):
|
||||
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function."""
|
||||
@wraps(method)
|
||||
def new_method(*args, **kwargs):
|
||||
param_dict = make_param_dict(method, args, kwargs)
|
||||
|
||||
# check node_list; required argument
|
||||
node_list = param_dict.get("node_list")
|
||||
check_list_or_ndarray(node_list, 'node_list')
|
||||
if isinstance(node_list, list):
|
||||
check_aligned_list(node_list, 'node_list')
|
||||
|
||||
# check feature_types; required argument
|
||||
check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types')
|
||||
|
||||
return method(*args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -32,6 +32,7 @@ from .mindpage import MindPage
|
|||
from .shardutils import SUCCESS, FAILED
|
||||
from .tools.cifar10_to_mr import Cifar10ToMR
|
||||
from .tools.cifar100_to_mr import Cifar100ToMR
|
||||
from .tools.graph_map_schema import GraphMapSchema
|
||||
from .tools.imagenet_to_mr import ImageNetToMR
|
||||
from .tools.mnist_to_mr import MnistToMR
|
||||
|
||||
|
|
|
@ -29,9 +29,10 @@ from .common.exceptions import *
|
|||
from .shardutils import SUCCESS, FAILED
|
||||
from .tools.cifar10_to_mr import Cifar10ToMR
|
||||
from .tools.cifar100_to_mr import Cifar100ToMR
|
||||
from .tools.graph_map_schema import GraphMapSchema
|
||||
from .tools.imagenet_to_mr import ImageNetToMR
|
||||
from .tools.mnist_to_mr import MnistToMR
|
||||
|
||||
__all__ = ['FileWriter', 'FileReader', 'MindPage',
|
||||
__all__ = ['FileWriter', 'FileReader', 'MindPage', 'GraphMapSchema',
|
||||
'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR',
|
||||
'SUCCESS', 'FAILED']
|
||||
|
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright 2019 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License")
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Graph data conversion tool for MindRecord.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['GraphMapSchema']
|
||||
|
||||
|
||||
class GraphMapSchema:
|
||||
"""
|
||||
Class for transforming graph data into the MindRecord union schema.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the fixed fields of the union schema; feature fields are added by the profile setters.
|
||||
"""
|
||||
self.num_node_features = 0
|
||||
self.num_edge_features = 0
|
||||
self.union_schema_in_mindrecord = {
|
||||
"first_id": {"type": "int64"},
|
||||
"second_id": {"type": "int64"},
|
||||
"third_id": {"type": "int64"},
|
||||
"type": {"type": "int32"},
|
||||
"attribute": {"type": "string"}, # 'n' for ndoe, 'e' for edge
|
||||
"node_feature_index": {"type": "int32", "shape": [-1]},
|
||||
"edge_feature_index": {"type": "int32", "shape": [-1]}
|
||||
}
|
||||
|
||||
def get_schema(self):
|
||||
"""
|
||||
Get schema
|
||||
"""
|
||||
return self.union_schema_in_mindrecord
|
||||
|
||||
def set_node_feature_profile(self, num_features, features_data_type, features_shape):
|
||||
"""
|
||||
Set node features profile
|
||||
"""
|
||||
if num_features != len(features_data_type) or num_features != len(features_shape):
|
||||
raise ValueError("Node feature profile is not match.")
|
||||
|
||||
self.num_node_features = num_features
|
||||
for i in range(num_features):
|
||||
k = i + 1
|
||||
field_key = 'node_feature_' + str(k)
|
||||
field_value = {"type": features_data_type[i], "shape": features_shape[i]}
|
||||
self.union_schema_in_mindrecord[field_key] = field_value
|
||||
|
||||
def set_edge_feature_profile(self, num_features, features_data_type, features_shape):
|
||||
"""
|
||||
Set edge features profile
|
||||
"""
|
||||
if num_features != len(features_data_type) or num_features != len(features_shape):
|
||||
raise ValueError("Edge feature profile is not match.")
|
||||
|
||||
self.num_edge_features = num_features
|
||||
for i in range(num_features):
|
||||
k = i + 1
|
||||
field_key = 'edge_feature_' + str(k)
|
||||
field_value = {"type": features_data_type[i], "shape": features_shape[i]}
|
||||
self.union_schema_in_mindrecord[field_key] = field_value
|
||||
|
||||
def transform_node(self, node):
|
||||
"""
|
||||
Executes transformation from node data to union format.
|
||||
Args:
|
||||
node(dict): node's data
|
||||
Returns:
|
||||
graph data with union schema
|
||||
"""
|
||||
node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"],
|
||||
"node_feature_index": []}
|
||||
for i in range(self.num_node_features):
|
||||
k = i + 1
|
||||
node_field_key = 'feature_' + str(k)
|
||||
graph_field_key = 'node_feature_' + str(k)
|
||||
graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
|
||||
if node_field_key in node:
|
||||
node_graph["node_feature_index"].append(k)
|
||||
node_graph[graph_field_key] = np.reshape(np.array(node[node_field_key], dtype=graph_field_type), [-1])
|
||||
else:
|
||||
node_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1])
|
||||
|
||||
if node_graph["node_feature_index"]:
|
||||
node_graph["node_feature_index"] = np.array(node_graph["node_feature_index"], dtype="int32")
|
||||
else:
|
||||
node_graph["node_feature_index"] = np.array([-1], dtype="int32")
|
||||
|
||||
node_graph["edge_feature_index"] = np.array([-1], dtype="int32")
|
||||
for i in range(self.num_edge_features):
|
||||
k = i + 1
|
||||
graph_field_key = 'edge_feature_' + str(k)
|
||||
graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
|
||||
node_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1])
|
||||
return node_graph
|
||||
|
||||
def transform_edge(self, edge):
|
||||
"""
|
||||
Executes transformation from edge data to union format.
|
||||
Args:
|
||||
edge(dict): edge's data
|
||||
Returns:
|
||||
graph data with union schema
|
||||
"""
|
||||
edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e',
|
||||
"type": edge["type"], "edge_feature_index": []}
|
||||
|
||||
for i in range(self.num_edge_features):
|
||||
k = i + 1
|
||||
edge_field_key = 'feature_' + str(k)
|
||||
graph_field_key = 'edge_feature_' + str(k)
|
||||
graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
|
||||
if edge_field_key in edge:
|
||||
edge_graph["edge_feature_index"].append(k)
|
||||
edge_graph[graph_field_key] = np.reshape(np.array(edge[edge_field_key], dtype=graph_field_type), [-1])
|
||||
else:
|
||||
edge_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1])
|
||||
|
||||
if edge_graph["edge_feature_index"]:
|
||||
edge_graph["edge_feature_index"] = np.array(edge_graph["edge_feature_index"], dtype="int32")
|
||||
else:
|
||||
edge_graph["edge_feature_index"] = np.array([-1], dtype="int32")
|
||||
|
||||
edge_graph["node_feature_index"] = np.array([-1], dtype="int32")
|
||||
for i in range(self.num_node_features):
|
||||
k = i + 1
|
||||
graph_field_key = 'node_feature_' + str(k)
|
||||
graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"]
|
||||
edge_graph[graph_field_key] = np.array([0], dtype=graph_field_type)
|
||||
return edge_graph
|
|
@ -70,6 +70,7 @@ SET(DE_UT_SRCS
|
|||
concat_op_test.cc
|
||||
jieba_tokenizer_op_test.cc
|
||||
tokenizer_op_test.cc
|
||||
gnn_graph_test.cc
|
||||
)
|
||||
|
||||
add_executable(de_ut_tests ${DE_UT_SRCS})
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "dataset/util/status.h"
|
||||
#include "dataset/engine/gnn/node.h"
|
||||
#include "dataset/engine/gnn/graph_loader.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using namespace mindspore::dataset::gnn;
|
||||
|
||||
class MindDataTestGNNGraph : public UT::Common {
|
||||
protected:
|
||||
MindDataTestGNNGraph() = default;
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphLoader gl(path, 4);
|
||||
EXPECT_TRUE(gl.InitAndLoad().IsOk());
|
||||
NodeIdMap n_id_map;
|
||||
EdgeIdMap e_id_map;
|
||||
NodeTypeMap n_type_map;
|
||||
EdgeTypeMap e_type_map;
|
||||
NodeFeatureMap n_feature_map;
|
||||
EdgeFeatureMap e_feature_map;
|
||||
DefaultFeatureMap default_feature_map;
|
||||
EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map,
|
||||
&default_feature_map)
|
||||
.IsOk());
|
||||
EXPECT_EQ(n_id_map.size(), 20);
|
||||
EXPECT_EQ(e_id_map.size(), 20);
|
||||
EXPECT_EQ(n_type_map[2].size(), 10);
|
||||
EXPECT_EQ(n_type_map[1].size(), 10);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
Graph graph(path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
std::vector<NodeMetaInfo> node_info;
|
||||
std::vector<EdgeMetaInfo> edge_info;
|
||||
s = graph.GetMetaInfo(&node_info, &edge_info);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(node_info.size() == 2);
|
||||
|
||||
std::shared_ptr<Tensor> nodes;
|
||||
s = graph.GetNodes(node_info[1].type, -1, &nodes);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::vector<NodeIdType> node_list;
|
||||
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
|
||||
node_list.push_back(*itr);
|
||||
if (node_list.size() >= 10) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::shared_ptr<Tensor> neighbors;
|
||||
s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
|
||||
TensorRow features;
|
||||
s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(features.size() == 3);
|
||||
EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>");
|
||||
EXPECT_TRUE(features[0]->ToString() ==
|
||||
"Tensor (shape: <10,5>, Type: int32)\n"
|
||||
"[[0,1,0,0,0],[1,0,0,0,1],[0,0,1,1,0],[0,0,0,0,0],[1,1,0,1,0],[0,0,0,0,1],[0,1,0,0,0],[0,0,0,1,1],[0,1,1,"
|
||||
"0,0],[0,1,0,1,0]]");
|
||||
EXPECT_TRUE(features[1]->shape().ToString() == "<10>");
|
||||
EXPECT_TRUE(features[1]->ToString() ==
|
||||
"Tensor (shape: <10>, Type: float32)\n[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]");
|
||||
EXPECT_TRUE(features[2]->shape().ToString() == "<10>");
|
||||
EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
|
||||
}
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import pytest
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
|
||||
|
||||
def test_graphdata_getfullneighbor():
|
||||
g = ds.GraphData(DATASET_FILE, 2)
|
||||
nodes = g.get_all_nodes(1)
|
||||
assert len(nodes) == 10
|
||||
nodes_list = nodes.tolist()
|
||||
neighbor = g.get_all_neighbors(nodes_list, 2)
|
||||
assert neighbor.shape == (10, 6)
|
||||
row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3])
|
||||
assert row_tensor[0].shape == (10, 6)
|
||||
|
||||
|
||||
def test_graphdata_getnodefeature_input_check():
|
||||
g = ds.GraphData(DATASET_FILE)
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [1, [1, 1]]
|
||||
g.get_node_feature(input_list, [1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], 1]
|
||||
g.get_node_feature(input_list, [1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [1, 1, 1]]
|
||||
g.get_node_feature(input_list, [1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1, 1], [1, 1]]
|
||||
g.get_node_feature(input_list, [1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [1, [1, 1]]]
|
||||
g.get_node_feature(input_list, [1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [[1, 1], 1]]
|
||||
g.get_node_feature(input_list, [1])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [1, 1]]
|
||||
g.get_node_feature(input_list, 1)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [1, 1]]
|
||||
g.get_node_feature(input_list, ["a"])
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
input_list = [[1, 1], [1, 1]]
|
||||
g.get_node_feature(input_list, [1, "a"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_graphdata_getfullneighbor()
|
||||
logger.info('test_graphdata_getfullneighbor Ended.\n')
|
||||
test_graphdata_getnodefeature_input_check()
|
||||
logger.info('test_graphdata_getnodefeature_input_check Ended.\n')
|
|
@ -89,6 +89,8 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file):
|
|||
num_iter = 0
|
||||
for item in data_set.create_dict_iterator():
|
||||
logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
|
||||
logger.info("-------------- item[data]: \
|
||||
{}------------------------".format(item["data"][:10]))
|
||||
logger.info("-------------- item[file_name]: \
|
||||
{}------------------------".format("".join([chr(x) for x in item["file_name"]])))
|
||||
logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
|
||||
|
|