diff --git a/example/graph_to_mindrecord/README.md b/example/graph_to_mindrecord/README.md new file mode 100644 index 00000000000..cc6f6a1c70b --- /dev/null +++ b/example/graph_to_mindrecord/README.md @@ -0,0 +1,73 @@ +# Guideline for Efficiently Generating MindRecord + + + +- [What does the example do](#what-does-the-example-do) +- [Example test for Cora](#example-test-for-cora) +- [How to use the example for other datasets](#how-to-use-the-example-for-other-datasets) +    - [Create work space](#create-work-space) +    - [Implement data generator](#implement-data-generator) +    - [Run data generator](#run-data-generator) + + + + +## What does the example do + +This example provides an efficient way to generate MindRecord. Users only need to define the parallel granularity of training data reading and the reading function of a single task; the example then converts the training data into MindRecord efficiently. + +1. write_cora.sh: entry script; users need to modify its parameters according to their own training data. +2. writer.py: main script, called by write_cora.sh; it reads user training data in parallel and generates MindRecord. +3. cora/mr_api.py: users define the parallel granularity of training data reading and the single-task reading functions in this file. + +## Example test for Cora + +1. Download and prepare the Cora dataset as required. + +    > [Cora dataset download address](https://github.com/jzaldi/datasets/tree/master/cora) + + +2. Edit write_cora.sh and modify the parameters: + ``` + --mindrecord_file: output MindRecord file. + --mindrecord_partitions: number of MindRecord partitions (output files). + ``` + +3. Run the bash script + ```bash + bash write_cora.sh + ``` + +## How to use the example for other datasets + +### Create work space + +Assume the dataset name is 'xyz'. +* Create a work space from cora + ```shell + cd ${your_mindspore_home}/example/graph_to_mindrecord + cp -r cora xyz + ``` + +### Implement data generator + +Edit the dictionary data generator. +* Edit file + ```shell + cd ${your_mindspore_home}/example/graph_to_mindrecord + vi xyz/mr_api.py + ``` + +The following interfaces, which writer.py imports from xyz/mr_api.py, must be provided (see the minimal sketch below): +- 'node_profile' and 'edge_profile': tuples of (num_features, feature_data_types, feature_shapes) describing the node and edge features. +- 'yield_nodes(task_id)' and 'yield_edges(task_id)': generators that yield node and edge rows as dictionaries. 'task_id' is 0..N-1, where N is the number of parallel tasks passed to writer.py via --num_node_tasks / --num_edge_tasks; keep the default of 1 if the data must be read serially. + +### Run data generator + +* Run the python script + ```shell + cd ${your_mindspore_home}/example/graph_to_mindrecord + python writer.py --mindrecord_script xyz [...] + ``` + > You can put this command in a script **write_xyz.sh** for easy execution. + diff --git a/example/graph_to_mindrecord/citeseer/__init__.py b/example/graph_to_mindrecord/citeseer/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/graph_to_mindrecord/citeseer/mr_api.py b/example/graph_to_mindrecord/citeseer/mr_api.py new file mode 100644 index 00000000000..8b1f424b0a0 --- /dev/null +++ b/example/graph_to_mindrecord/citeseer/mr_api.py @@ -0,0 +1,109 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
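For reference, the "Implement data generator" section of the README above boils down to a file like the following minimal sketch. The dataset name 'xyz', its CSV layout, and the single float feature are hypothetical; the module-level names (node_profile, edge_profile, yield_nodes, yield_edges, and the 'graph_api_args' environment variable) mirror what cora/mr_api.py defines and what writer.py in this change imports.

```python
# example/graph_to_mindrecord/xyz/mr_api.py -- minimal sketch (hypothetical dataset layout)
import csv
import os

# writer.py passes the data file paths through the 'graph_api_args' environment
# variable, ':'-separated (see --graph_api_args in write_cora.sh).
args = os.environ['graph_api_args'].split(':')
XYZ_NODE_FILE = args[0]
XYZ_EDGE_FILE = args[1]

# profile: (num_features, feature_data_types, feature_shapes)
node_profile = (1, ["float32"], [[-1]])
edge_profile = (0, [], [])


def yield_nodes(task_id=0):
    """Yield one node dict per row of the (hypothetical) node CSV: id,feat1,feat2,..."""
    with open(XYZ_NODE_FILE) as f:
        for row in csv.reader(f):
            # feature_1 is kept 2-D, mirroring the cora example
            yield {'id': int(row[0]), 'type': 0, 'feature_1': [[float(x) for x in row[1:]]]}


def yield_edges(task_id=0):
    """Yield one edge dict per row of the (hypothetical) edge CSV: src_id,dst_id."""
    with open(XYZ_EDGE_FILE) as f:
        for edge_id, row in enumerate(csv.reader(f), start=1):
            yield {'id': edge_id, 'src_id': int(row[0]), 'dst_id': int(row[1]), 'type': 0}
```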
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord GNN writer. +""" +import csv +import os + +import numpy as np +import scipy.sparse as sp + +# parse args from command line parameter 'graph_api_args' +# args delimiter is ':' +args = os.environ['graph_api_args'].split(':') +CITESEER_CONTENT_FILE = args[0] +CITESEER_CITES_FILE = args[1] +CITESEER_MINDRECRD_LABEL_FILE = CITESEER_CONTENT_FILE + "_label_mindrecord" +CITESEER_MINDRECRD_ID_MAP_FILE = CITESEER_CONTENT_FILE + "_id_mindrecord" + +node_id_map = {} + +# profile: (num_features, feature_data_types, feature_shapes) +node_profile = (2, ["float32", "int64"], [[-1], [-1]]) +edge_profile = (0, [], []) + + +def _normalize_citeseer_features(features): + features = np.array(features) + row_sum = np.array(features.sum(1)) + r_inv = np.power(row_sum * 1.0, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + features = r_mat_inv.dot(features) + return features + + +def yield_nodes(task_id=0): + """ + Generate node data + + Yields: + data (dict): data row which is dict. + """ + print("Node task is {}".format(task_id)) + label_types = {} + label_size = 0 + node_num = 0 + with open(CITESEER_CONTENT_FILE) as content_file: + content_reader = csv.reader(content_file, delimiter='\t') + line_count = 0 + for row in content_reader: + if not row[-1] in label_types: + label_types[row[-1]] = label_size + label_size += 1 + if not row[0] in node_id_map: + node_id_map[row[0]] = node_num + node_num += 1 + raw_features = [[int(x) for x in row[1:-1]]] + node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_citeseer_features(raw_features), + 'feature_2': [label_types[row[-1]]]} + yield node + line_count += 1 + print('Processed {} lines for nodes.'.format(line_count)) + # print('label types {}.'.format(label_types)) + with open(CITESEER_MINDRECRD_LABEL_FILE, 'w') as f: + for k in label_types: + print(k + ',' + str(label_types[k]), file=f) + + +def yield_edges(task_id=0): + """ + Generate edge data + + Yields: + data (dict): data row which is dict. 
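As a quick sanity check of what the `_normalize_citeseer_features` helper above computes (each feature row is divided by its row sum; rows that sum to zero are left as zeros), here is the same logic applied to a tiny made-up row:

```python
import numpy as np
import scipy.sparse as sp

features = np.array([[1, 0, 1, 1]], dtype=np.float32)
row_sum = features.sum(1)                      # [3.]
r_inv = np.power(row_sum * 1.0, -1).flatten()  # [0.3333...]
r_inv[np.isinf(r_inv)] = 0.                    # rows summing to zero stay zero instead of inf
print(sp.diags(r_inv).dot(features))           # [[0.3333 0. 0.3333 0.3333]]
```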
+ """ + print("Edge task is {}".format(task_id)) + # print(map_string_int) + with open(CITESEER_CITES_FILE) as cites_file: + cites_reader = csv.reader(cites_file, delimiter='\t') + line_count = 0 + for row in cites_reader: + if not row[0] in node_id_map: + print('Source node {} does not exist.'.format(row[0])) + continue + if not row[1] in node_id_map: + print('Destination node {} does not exist.'.format(row[1])) + continue + line_count += 1 + edge = {'id': line_count, + 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} + yield edge + + with open(CITESEER_MINDRECRD_ID_MAP_FILE, 'w') as f: + for k in node_id_map: + print(k + ',' + str(node_id_map[k]), file=f) + print('Processed {} lines for edges.'.format(line_count)) diff --git a/example/graph_to_mindrecord/cora/__init__.py b/example/graph_to_mindrecord/cora/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/example/graph_to_mindrecord/cora/mr_api.py b/example/graph_to_mindrecord/cora/mr_api.py new file mode 100644 index 00000000000..0963fd78f7d --- /dev/null +++ b/example/graph_to_mindrecord/cora/mr_api.py @@ -0,0 +1,113 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord GNN writer. +""" +import csv +import os + +import numpy as np +import scipy.sparse as sp + +# parse args from command line parameter 'graph_api_args' +# args delimiter is ':' +args = os.environ['graph_api_args'].split(':') +CORA_CONTENT_FILE = args[0] +CORA_CITES_FILE = args[1] +CORA_MINDRECRD_LABEL_FILE = CORA_CONTENT_FILE + "_label_mindrecord" +CORA_CONTENT_ID_MAP_FILE = CORA_CONTENT_FILE + "_id_mindrecord" + +node_id_map = {} + +# profile: (num_features, feature_data_types, feature_shapes) +node_profile = (2, ["float32", "int64"], [[-1], [-1]]) +edge_profile = (0, [], []) + + +def _normalize_cora_features(features): + features = np.array(features) + row_sum = np.array(features.sum(1)) + r_inv = np.power(row_sum * 1.0, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + features = r_mat_inv.dot(features) + return features + + +def yield_nodes(task_id=0): + """ + Generate node data + + Yields: + data (dict): data row which is dict. 
+ """ + print("Node task is {}".format(task_id)) + label_types = {} + label_size = 0 + node_num = 0 + with open(CORA_CONTENT_FILE) as content_file: + content_reader = csv.reader(content_file, delimiter=',') + line_count = 0 + for row in content_reader: + if line_count == 0: + line_count += 1 + continue + if not row[0] in node_id_map: + node_id_map[row[0]] = node_num + node_num += 1 + if not row[-1] in label_types: + label_types[row[-1]] = label_size + label_size += 1 + raw_features = [[int(x) for x in row[1:-1]]] + node = {'id': node_id_map[row[0]], 'type': 0, 'feature_1': _normalize_cora_features(raw_features), + 'feature_2': [label_types[row[-1]]]} + yield node + line_count += 1 + print('Processed {} lines for nodes.'.format(line_count)) + print('label types {}.'.format(label_types)) + with open(CORA_MINDRECRD_LABEL_FILE, 'w') as f: + for k in label_types: + print(k + ',' + str(label_types[k]), file=f) + + +def yield_edges(task_id=0): + """ + Generate edge data + + Yields: + data (dict): data row which is dict. + """ + print("Edge task is {}".format(task_id)) + with open(CORA_CITES_FILE) as cites_file: + cites_reader = csv.reader(cites_file, delimiter=',') + line_count = 0 + for row in cites_reader: + if line_count == 0: + line_count += 1 + continue + if not row[0] in node_id_map: + print('Source node {} does not exist.'.format(row[0])) + continue + if not row[1] in node_id_map: + print('Destination node {} does not exist.'.format(row[1])) + continue + edge = {'id': line_count, + 'src_id': node_id_map[row[0]], 'dst_id': node_id_map[row[1]], 'type': 0} + yield edge + line_count += 1 + print('Processed {} lines for edges.'.format(line_count)) + with open(CORA_CONTENT_ID_MAP_FILE, 'w') as f: + for k in node_id_map: + print(k + ',' + str(node_id_map[k]), file=f) diff --git a/example/graph_to_mindrecord/read_citeseer.sh b/example/graph_to_mindrecord/read_citeseer.sh new file mode 100644 index 00000000000..5d79a2ee96c --- /dev/null +++ b/example/graph_to_mindrecord/read_citeseer.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python reader.py --path "/tmp/citeseer/mindrecord/citeseer_mr" diff --git a/example/graph_to_mindrecord/read_cora.sh b/example/graph_to_mindrecord/read_cora.sh new file mode 100644 index 00000000000..94b95650504 --- /dev/null +++ b/example/graph_to_mindrecord/read_cora.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python reader.py --path "/tmp/cora/mindrecord/cora_mr" diff --git a/example/graph_to_mindrecord/reader.py b/example/graph_to_mindrecord/reader.py new file mode 100644 index 00000000000..637ce41a171 --- /dev/null +++ b/example/graph_to_mindrecord/reader.py @@ -0,0 +1,34 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +######################## write mindrecord example ######################## +Write mindrecord by data dictionary: +python writer.py --mindrecord_script /YourScriptPath ... 
+""" +import argparse + +import mindspore.dataset as ds + +parser = argparse.ArgumentParser(description='Mind record reader') +parser.add_argument('--path', type=str, default="/tmp/cora/mindrecord/cora_mr", + help='data file') +args = parser.parse_args() + +data_set = ds.MindDataset(args.path) +num_iter = 0 +for item in data_set.create_dict_iterator(): + print(item) + num_iter += 1 +print("Total items # is {}".format(num_iter)) diff --git a/example/graph_to_mindrecord/write_citeseer.sh b/example/graph_to_mindrecord/write_citeseer.sh new file mode 100644 index 00000000000..0d5093f18d2 --- /dev/null +++ b/example/graph_to_mindrecord/write_citeseer.sh @@ -0,0 +1,9 @@ +#!/bin/bash +rm /tmp/citeseer/mindrecord/* + +python writer.py --mindrecord_script citeseer \ +--mindrecord_file "/tmp/citeseer/mindrecord/citeseer_mr" \ +--mindrecord_partitions 1 \ +--mindrecord_header_size_by_bit 18 \ +--mindrecord_page_size_by_bit 20 \ +--graph_api_args "/tmp/citeseer/dataset/citeseer.content:/tmp/citeseer/dataset/citeseer.cites" diff --git a/example/graph_to_mindrecord/write_cora.sh b/example/graph_to_mindrecord/write_cora.sh new file mode 100644 index 00000000000..6ba321ef035 --- /dev/null +++ b/example/graph_to_mindrecord/write_cora.sh @@ -0,0 +1,9 @@ +#!/bin/bash +rm /tmp/cora/mindrecord/* + +python writer.py --mindrecord_script cora \ +--mindrecord_file "/tmp/cora/mindrecord/cora_mr" \ +--mindrecord_partitions 1 \ +--mindrecord_header_size_by_bit 18 \ +--mindrecord_page_size_by_bit 20 \ +--graph_api_args "/tmp/cora/dataset/cora_content.csv:/tmp/cora/dataset/cora_cites.csv" diff --git a/example/graph_to_mindrecord/writer.py b/example/graph_to_mindrecord/writer.py new file mode 100644 index 00000000000..4ada8d60ea8 --- /dev/null +++ b/example/graph_to_mindrecord/writer.py @@ -0,0 +1,186 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +######################## write mindrecord example ######################## +Write mindrecord by data dictionary: +python writer.py --mindrecord_script /YourScriptPath ... 
+""" +import argparse +import os +import time +from importlib import import_module +from multiprocessing import Pool + +from mindspore.mindrecord import FileWriter +from mindspore.mindrecord import GraphMapSchema + + +def exec_task(task_id, parallel_writer=True): + """ + Execute task with specified task id + """ + print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) + imagenet_iter = mindrecord_dict_data(task_id) + batch_size = 512 + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data = imagenet_iter.__next__() + if 'dst_id' in data: + data = graph_map_schema.transform_edge(data) + else: + data = graph_map_schema.transform_node(data) + data_list.append(data) + transform_count += 1 + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + except StopIteration: + if data_list: + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + break + + +def read_args(): + """ + read args + """ + parser = argparse.ArgumentParser(description='Mind record writer') + parser.add_argument('--mindrecord_script', type=str, default="template", + help='path where script is saved') + + parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord/xyz", + help='written file name prefix') + + parser.add_argument('--mindrecord_partitions', type=int, default=1, + help='number of written files') + + parser.add_argument('--mindrecord_header_size_by_bit', type=int, default=24, + help='mindrecord file header size') + + parser.add_argument('--mindrecord_page_size_by_bit', type=int, default=25, + help='mindrecord file page size') + + parser.add_argument('--mindrecord_workers', type=int, default=8, + help='number of parallel workers') + + parser.add_argument('--num_node_tasks', type=int, default=1, + help='number of node tasks') + + parser.add_argument('--num_edge_tasks', type=int, default=1, + help='number of node tasks') + + parser.add_argument('--graph_api_args', type=str, default="/tmp/nodes.csv:/tmp/edges.csv", + help='nodes and edges data file, csv format with header.') + + ret_args = parser.parse_args() + + return ret_args + + +def init_writer(mr_schema): + """ + init writer + """ + print("Init writer ...") + mr_writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) + + # set the header size + if args.mindrecord_header_size_by_bit != 24: + header_size = 1 << args.mindrecord_header_size_by_bit + mr_writer.set_header_size(header_size) + + # set the page size + if args.mindrecord_page_size_by_bit != 25: + page_size = 1 << args.mindrecord_page_size_by_bit + mr_writer.set_page_size(page_size) + + # create the schema + mr_writer.add_schema(mr_schema, "mindrecord_graph_schema") + + # open file and set header + mr_writer.open_and_set_header() + + return mr_writer + + +def run_parallel_workers(num_tasks): + """ + run parallel workers + """ + # set number of workers + num_workers = args.mindrecord_workers + + task_list = list(range(num_tasks)) + + if num_workers > num_tasks: + num_workers = num_tasks + + if os.name == 'nt': + for window_task_id in task_list: + exec_task(window_task_id, False) + elif num_tasks > 1: + with Pool(num_workers) as p: + p.map(exec_task, task_list) + else: + exec_task(0, False) + + +if __name__ == "__main__": + args = read_args() + print(args) + + start_time = time.time() + + # pass mr_api arguments + os.environ['graph_api_args'] = args.graph_api_args + + # import 
mr_api + try: + mr_api = import_module(args.mindrecord_script + '.mr_api') + except ModuleNotFoundError: + raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) + + # init graph schema + graph_map_schema = GraphMapSchema() + + num_features, feature_data_types, feature_shapes = mr_api.node_profile + graph_map_schema.set_node_feature_profile(num_features, feature_data_types, feature_shapes) + + num_features, feature_data_types, feature_shapes = mr_api.edge_profile + graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes) + + graph_schema = graph_map_schema.get_schema() + + # init writer + writer = init_writer(graph_schema) + + # write nodes data + mindrecord_dict_data = mr_api.yield_nodes + run_parallel_workers(args.num_node_tasks) + + # write edges data + mindrecord_dict_data = mr_api.yield_edges + run_parallel_workers(args.num_edge_tasks) + + # writer wrap up + ret = writer.commit() + + end_time = time.time() + print("--------------------------------------------") + print("END. Total time: {}".format(end_time - start_time)) + print("--------------------------------------------") diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt index 6abd9286c24..2876c2c06b3 100644 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/CMakeLists.txt @@ -66,6 +66,7 @@ set(submodules $ $ $ + $ $ $ $ diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index bde85bf78ef..0a0e8c364e3 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -61,6 +61,7 @@ #include "dataset/engine/jagged_connector.h" #include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/engine/datasetops/source/voc_op.h" +#include "dataset/engine/gnn/graph.h" #include "dataset/kernels/data/to_float16_op.h" #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" @@ -513,6 +514,33 @@ void bindVocabObjects(py::module *m) { }); } +void bindGraphData(py::module *m) { + (void)py::class_>(*m, "Graph") + .def(py::init([](std::string dataset_file, int32_t num_workers) { + std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); + THROW_IF_ERROR(g_out->Init()); + return g_out; + })) + .def("get_nodes", + [](gnn::Graph &g, gnn::NodeType node_type, gnn::NodeIdType node_num) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetNodes(node_type, node_num, &out)); + return out; + }) + .def("get_all_neighbors", + [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { + std::shared_ptr out; + THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); + return out; + }) + .def("get_node_feature", + [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { + TensorRow out; + THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); + return out; + }); +} + // This is where we externalize the C logic as python modules PYBIND11_MODULE(_c_dataengine, m) { m.doc() = "pybind11 for _c_dataengine"; @@ -578,6 +606,7 @@ PYBIND11_MODULE(_c_dataengine, m) { bindDatasetOps(&m); bindInfoObjects(&m); bindVocabObjects(&m); + bindGraphData(&m); } } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/CMakeLists.txt index 9d01fca9143..e7b5e682f38 100644 --- a/mindspore/ccsrc/dataset/engine/CMakeLists.txt +++ 
b/mindspore/ccsrc/dataset/engine/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(datasetops) add_subdirectory(opt) +add_subdirectory(gnn) if (ENABLE_TDTQUE) add_subdirectory(tdt) endif () @@ -15,7 +16,7 @@ add_library(engine OBJECT target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS}) if (ENABLE_TDTQUE) - add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt) + add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn) else() - add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt) + add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn) endif () diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index e7ed0e12a3f..49c7e78a609 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -112,10 +112,10 @@ Status MindRecordOp::Init() { data_schema_ = std::make_unique(); - std::vector col_names = shard_reader_->get_shard_column()->GetColumnName(); + std::vector col_names = shard_reader_->GetShardColumn()->GetColumnName(); CHECK_FAIL_RETURN_UNEXPECTED(!col_names.empty(), "No schema found"); - std::vector col_data_types = shard_reader_->get_shard_column()->GeColumnDataType(); - std::vector> col_shapes = shard_reader_->get_shard_column()->GetColumnShape(); + std::vector col_data_types = shard_reader_->GetShardColumn()->GeColumnDataType(); + std::vector> col_shapes = shard_reader_->GetShardColumn()->GetColumnShape(); bool load_all_cols = columns_to_load_.empty(); // if columns_to_load_ is empty it means load everything std::map colname_to_ind; @@ -296,8 +296,7 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector column_shape; // Get column data - - auto has_column = shard_reader_->get_shard_column()->GetColumnValueByName( + auto has_column = shard_reader_->GetShardColumn()->GetColumnValueByName( column_name, columns_blob, columns_json, &data, &data_ptr, &n_bytes, &column_data_type, &column_data_type_size, &column_shape); if (has_column == MSRStatus::FAILED) { diff --git a/mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt new file mode 100644 index 00000000000..d7a295e32a8 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(engine-gnn OBJECT + graph.cc + graph_loader.cc + local_node.cc + local_edge.cc + feature.cc + ) diff --git a/mindspore/ccsrc/dataset/engine/gnn/edge.h b/mindspore/ccsrc/dataset/engine/gnn/edge.h new file mode 100644 index 00000000000..47314d97c24 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/edge.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_GNN_EDGE_H_ +#define DATASET_ENGINE_GNN_EDGE_H_ + +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/engine/gnn/feature.h" +#include "dataset/engine/gnn/node.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using EdgeType = int8_t; +using EdgeIdType = int32_t; + +class Edge { + public: + // Constructor + // @param EdgeIdType id - edge id + // @param EdgeType type - edge type + // @param std::shared_ptr src_node - source node + // @param std::shared_ptr dst_node - destination node + Edge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) + : id_(id), type_(type), src_node_(src_node), dst_node_(dst_node) {} + + virtual ~Edge() = default; + + // @return NodeIdType - Returned edge id + EdgeIdType id() const { return id_; } + + // @return NodeIdType - Returned edge type + EdgeType type() const { return type_; } + + // Get the feature of a edge + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; + + // Get nodes on the edge + // @param std::pair, std::shared_ptr> *out_node - Source and destination nodes returned + Status GetNode(std::pair, std::shared_ptr> *out_node) { + *out_node = std::make_pair(src_node_, dst_node_); + return Status::OK(); + } + + // Set node to edge + // @param const std::pair, std::shared_ptr> &in_node - + Status SetNode(const std::pair, std::shared_ptr> &in_node) { + src_node_ = in_node.first; + dst_node_ = in_node.second; + return Status::OK(); + } + + // Update feature of edge + // @param std::shared_ptr feature - + // @return Status - The error code return + virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; + + protected: + EdgeIdType id_; + EdgeType type_; + std::shared_ptr src_node_; + std::shared_ptr dst_node_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_EDGE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/feature.cc b/mindspore/ccsrc/dataset/engine/gnn/feature.cc new file mode 100644 index 00000000000..e4579478217 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/feature.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +Feature::Feature(FeatureType type_name, std::shared_ptr value) : type_name_(type_name), value_(value) {} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/feature.h b/mindspore/ccsrc/dataset/engine/gnn/feature.h new file mode 100644 index 00000000000..956aba49b11 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/feature.h @@ -0,0 +1,50 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_FEATURE_H_ +#define DATASET_ENGINE_GNN_FEATURE_H_ + +#include + +#include "dataset/core/tensor.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using FeatureType = int16_t; + +class Feature { + public: + // Constructor + // @param FeatureType type_name - feature type + // @param std::shared_ptr value - feature value + Feature(FeatureType type_name, std::shared_ptr value); + + // Get feature value + // @return std::shared_ptr *out_value - feature value + const std::shared_ptr Value() const { return value_; } + + // @return NodeIdType - Returned feature type + FeatureType type() const { return type_name_; } + + private: + FeatureType type_name_; + std::shared_ptr value_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_FEATURE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/dataset/engine/gnn/graph.cc new file mode 100644 index 00000000000..9dcca723396 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.cc @@ -0,0 +1,251 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "dataset/engine/gnn/graph.h" + +#include +#include +#include +#include + +#include "dataset/core/tensor_shape.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +Graph::Graph(std::string dataset_file, int32_t num_workers) : dataset_file_(dataset_file), num_workers_(num_workers) { + MS_LOG(INFO) << "num_workers:" << num_workers; +} + +Status Graph::GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr *out) { + auto itr = node_type_map_.find(node_type); + if (itr == node_type_map_.end()) { + std::string err_msg = "Invalid node type:" + std::to_string(node_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + if (node_num == -1) { + RETURN_IF_NOT_OK(CreateTensorByVector({itr->second}, DataType(DataType::DE_INT32), out)); + } else { + } + } + return Status::OK(); +} + +template +Status Graph::CreateTensorByVector(const std::vector> &data, DataType type, + std::shared_ptr *out) { + if (!type.IsCompatible()) { + RETURN_STATUS_UNEXPECTED("Data type not compatible"); + } + if (data.empty()) { + RETURN_STATUS_UNEXPECTED("Input data is emply"); + } + std::shared_ptr tensor; + size_t m = data.size(); + size_t n = data[0].size(); + RETURN_IF_NOT_OK(Tensor::CreateTensor( + &tensor, TensorImpl::kFlexible, TensorShape({static_cast(m), static_cast(n)}), type, nullptr)); + T *ptr = reinterpret_cast(tensor->GetMutableBuffer()); + for (auto id_m : data) { + CHECK_FAIL_RETURN_UNEXPECTED(id_m.size() == n, "Each member of the vector has a different size"); + for (auto id_n : id_m) { + *ptr = id_n; + ptr++; + } + } + tensor->Squeeze(); + *out = std::move(tensor); + return Status::OK(); +} + +template +Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { + if (!data || data->empty()) { + RETURN_STATUS_UNEXPECTED("Input data is emply"); + } + for (std::vector &vec : *data) { + size_t size = vec.size(); + if (size > max_size) { + RETURN_STATUS_UNEXPECTED("The max_size parameter is abnormal"); + } else { + for (size_t i = 0; i < (max_size - size); ++i) { + vec.push_back(default_value); + } + } + } + return Status::OK(); +} + +Status Graph::GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr *out) { return Status::OK(); } + +Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) { + if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { + std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } + + std::vector> neighbors; + size_t max_neighbor_num = 0; + neighbors.resize(node_list.size()); + for (size_t i = 0; i < node_list.size(); ++i) { + auto itr = node_id_map_.find(node_list[i]); + if (itr != node_id_map_.end()) { + RETURN_IF_NOT_OK(itr->second->GetNeighbors(neighbor_type, -1, &neighbors[i])); + max_neighbor_num = max_neighbor_num > neighbors[i].size() ? 
max_neighbor_num : neighbors[i].size(); + } else { + std::string err_msg = "Invalid node id:" + std::to_string(node_list[i]); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + + RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); + + return Status::OK(); +} + +Status Graph::GetSampledNeighbor(const std::vector &node_list, const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) { + return Status::OK(); +} + +Status Graph::GetNegSampledNeighbor(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) { + return Status::OK(); +} + +Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, + float q, NodeIdType default_node, std::shared_ptr *out) { + return Status::OK(); +} + +Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_feature_map_.find(feature_type); + if (itr == default_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out) { + if (!nodes || nodes->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Inpude nodes is empty"); + } + TensorRow tensors; + for (auto f_type : feature_types) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetNodeDefaultFeature(f_type, &default_feature)); + + TensorShape shape(default_feature->Value()->shape()); + auto shape_vec = nodes->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&fea_tensor, TensorImpl::kFlexible, shape, default_feature->Value()->type(), nullptr)); + + dsize_t index = 0; + for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { + auto itr = node_id_map_.find(*node_itr); + std::shared_ptr feature; + if (itr != node_id_map_.end()) { + if (!itr->second->GetFeatures(f_type, &feature).IsOk()) { + feature = default_feature; + } + } else { + if (*node_itr == kDefaultNodeId) { + feature = default_feature; + } else { + std::string err_msg = "Invalid node id:" + std::to_string(*node_itr); + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value())); + index++; + } + + TensorShape reshape(nodes->shape()); + for (auto s : default_feature->Value()->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + tensors.push_back(fea_tensor); + } + *out = std::move(tensors); + return Status::OK(); +} + +Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, + TensorRow *out) { + return Status::OK(); +} + +Status Graph::Init() { + RETURN_IF_NOT_OK(LoadNodeAndEdge()); + return Status::OK(); +} + +Status Graph::GetMetaInfo(std::vector *node_info, std::vector *edge_info) { + node_info->reserve(node_type_map_.size()); + for (auto node : node_type_map_) { + NodeMetaInfo n_info; + n_info.type = node.first; + n_info.num = node.second.size(); + auto itr = 
node_feature_map_.find(node.first); + if (itr != node_feature_map_.end()) { + for (auto f_type : itr->second) { + n_info.feature_type.push_back(f_type); + } + std::sort(n_info.feature_type.begin(), n_info.feature_type.end()); + } + node_info->push_back(n_info); + } + + edge_info->reserve(edge_type_map_.size()); + for (auto edge : edge_type_map_) { + EdgeMetaInfo e_info; + e_info.type = edge.first; + e_info.num = edge.second.size(); + auto itr = edge_feature_map_.find(edge.first); + if (itr != edge_feature_map_.end()) { + for (auto f_type : itr->second) { + e_info.feature_type.push_back(f_type); + } + } + edge_info->push_back(e_info); + } + return Status::OK(); +} + +Status Graph::LoadNodeAndEdge() { + GraphLoader gl(dataset_file_, num_workers_); + // ask graph_loader to load everything into memory + RETURN_IF_NOT_OK(gl.InitAndLoad()); + // get all maps + RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, + &node_feature_map_, &edge_feature_map_, &default_feature_map_)); + return Status::OK(); +} +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph.h b/mindspore/ccsrc/dataset/engine/gnn/graph.h new file mode 100644 index 00000000000..027ba53aeb1 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/graph.h @@ -0,0 +1,166 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_GRAPH_H_ +#define DATASET_ENGINE_GNN_GRAPH_H_ + +#include +#include +#include +#include +#include + +#include "dataset/core/tensor.h" +#include "dataset/engine/gnn/graph_loader.h" +#include "dataset/engine/gnn/feature.h" +#include "dataset/engine/gnn/node.h" +#include "dataset/engine/gnn/edge.h" +#include "dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +struct NodeMetaInfo { + NodeType type; + NodeIdType num; + std::vector feature_type; + NodeMetaInfo() { + type = 0; + num = 0; + } +}; + +struct EdgeMetaInfo { + EdgeType type; + EdgeIdType num; + std::vector feature_type; + EdgeMetaInfo() { + type = 0; + num = 0; + } +}; + +class Graph { + public: + // Constructor + // @param std::string dataset_file - + // @param int32_t num_workers - number of parallel threads + Graph(std::string dataset_file, int32_t num_workers); + + ~Graph() = default; + + // Get the nodes from the graph. + // @param NodeType node_type - type of node + // @param NodeIdType node_num - Number of nodes to be acquired, if -1 means all nodes are acquired + // @param std::shared_ptr *out - Returned nodes id + // @return Status - The error code return + Status GetNodes(NodeType node_type, NodeIdType node_num, std::shared_ptr *out); + + // Get the edges from the graph. 
+ // @param NodeType edge_type - type of edge + // @param NodeIdType edge_num - Number of edges to be acquired, if -1 means all edges are acquired + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The error code return + Status GetEdges(EdgeType edge_type, EdgeIdType edge_num, std::shared_ptr *out); + + // All neighbors of the acquisition node. + // @param std::vector node_list - List of nodes + // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is + // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors + // is not enough, fill in tensor as -1. + // @return Status - The error code return + Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out); + + Status GetSampledNeighbor(const std::vector &node_list, const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out); + Status GetNegSampledNeighbor(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out); + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, float p, float q, + NodeIdType default_node, std::shared_ptr *out); + + // Get the feature of a node + // @param std::shared_ptr nodes - List of nodes + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. + // @param TensorRow *out - Returned features + // @return Status - The error code return + Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out); + + // Get the feature of a edge + // @param std::shared_ptr edget - List of edges + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. 
+ // @param Tensor *out - Returned features + // @return Status - The error code return + Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, + TensorRow *out); + + // Get meta information of graph + // @param std::vector *node_info - Returned meta information of node + // @param std::vector *node_info - Returned meta information of edge + // @return Status - The error code return + Status GetMetaInfo(std::vector *node_info, std::vector *edge_info); + + Status Init(); + + private: + // Load graph data from mindrecord file + // @return Status - The error code return + Status LoadNodeAndEdge(); + + // Create Tensor By Vector + // @param std::vector> &data - + // @param DataType type - + // @param std::shared_ptr *out - + // @return Status - The error code return + template + Status CreateTensorByVector(const std::vector> &data, DataType type, std::shared_ptr *out); + + // Complete vector + // @param std::vector> *data - To be completed vector + // @param size_t max_size - The size of the completed vector + // @param T default_value - Filled default + // @return Status - The error code return + template + Status ComplementVector(std::vector> *data, size_t max_size, T default_value); + + // Get the default feature of a node + // @param FeatureType feature_type - + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + std::string dataset_file_; + int32_t num_workers_; // The number of worker threads + + std::unordered_map> node_type_map_; + std::unordered_map> node_id_map_; + + std::unordered_map> edge_type_map_; + std::unordered_map> edge_id_map_; + + std::unordered_map> node_feature_map_; + std::unordered_map> edge_feature_map_; + + std::unordered_map> default_feature_map_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_GRAPH_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc new file mode 100644 index 00000000000..127769bd68a --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc @@ -0,0 +1,248 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include "dataset/engine/gnn/graph_loader.h" +#include "mindspore/ccsrc/mindrecord/include/shard_error.h" +#include "dataset/engine/gnn/local_edge.h" +#include "dataset/engine/gnn/local_node.h" + +using ShardTuple = std::vector, mindspore::mindrecord::json>>; + +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::MSRStatus; + +GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) + : mr_path_(mr_filepath), + num_workers_(num_workers), + row_id_(0), + keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} + +Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, + EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, + EdgeFeatureMap *e_feature_map, DefaultFeatureMap *default_feature_map) { + for (std::deque> &dq : n_deques_) { + while (dq.empty() == false) { + std::shared_ptr node_ptr = dq.front(); + n_id_map->insert({node_ptr->id(), node_ptr}); + (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); + dq.pop_front(); + } + } + + for (std::deque> &dq : e_deques_) { + while (dq.empty() == false) { + std::shared_ptr edge_ptr = dq.front(); + std::pair, std::shared_ptr> p; + RETURN_IF_NOT_OK(edge_ptr->GetNode(&p)); + auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id()); + CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id:" + std::to_string(src_itr->first)); + CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id:" + std::to_string(dst_itr->first)); + RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); + RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); + e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ + (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); + dq.pop_front(); + } + } + + for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); + for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); + + MergeFeatureMaps(n_feature_map, e_feature_map, default_feature_map); + return Status::OK(); +} + +Status GraphLoader::InitAndLoad() { + CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n"); + CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n"); + n_deques_.resize(num_workers_); + e_deques_.resize(num_workers_); + n_feature_maps_.resize(num_workers_); + e_feature_maps_.resize(num_workers_); + default_feature_maps_.resize(num_workers_); + std::vector> r_codes(num_workers_); + + shard_reader_ = std::make_unique(); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Open({mr_path_}, true, num_workers_) == MSRStatus::SUCCESS, + "Fail to open" + mr_path_); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); + + mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; + for (const std::string &key : keys_) { + if (schema.find(key) == schema.end()) { + RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); + } + } + + // launching worker threads + for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { + r_codes[wkr_id] = std::async(std::launch::async, &GraphLoader::WorkerEntry, this, wkr_id); + } + // wait for threads to finish and check its return code + for (int wkr_id = 0; wkr_id < 
num_workers_; ++wkr_id) { + RETURN_IF_NOT_OK(r_codes[wkr_id].get()); + } + return Status::OK(); +} + +Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrecord::json &col_jsn, + std::shared_ptr *node, NodeFeatureMap *feature_map, + DefaultFeatureMap *default_feature) { + NodeIdType node_id = col_jsn["first_id"]; + NodeType node_type = static_cast(col_jsn["type"]); + (*node) = std::make_shared(node_id, node_type); + std::vector indices; + RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); + + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); + RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[node_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } + return Status::OK(); +} + +Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrecord::json &col_jsn, + std::shared_ptr *edge, EdgeFeatureMap *feature_map, + DefaultFeatureMap *default_feature) { + EdgeIdType edge_id = col_jsn["first_id"]; + EdgeType edge_type = static_cast(col_jsn["type"]); + NodeIdType src_id = col_jsn["second_id"], dst_id = col_jsn["third_id"]; + std::shared_ptr src = std::make_shared(src_id, -1); + std::shared_ptr dst = std::make_shared(dst_id, -1); + (*edge) = std::make_shared(edge_id, edge_type, src, dst); + std::vector indices; + RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); + RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[edge_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateTensor(&zero_tensor, TensorImpl::kFlexible, tensor->shape(), tensor->type())); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } + return Status::OK(); +} + +Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, + const mindrecord::json &col_jsn, std::shared_ptr *tensor) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( + key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, TensorImpl::kFlexible, + std::move(TensorShape({static_cast(n_bytes / col_type_size)})), + std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), data)); + return Status::OK(); +} + +Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, + const mindrecord::json &col_jsn, std::vector *indices) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + 
uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( + key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); + + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + + for (int i = 0; i < n_bytes; i += col_type_size) { + int32_t feature_ind = -1; + if (col_type == mindrecord::ColumnInt32) { + feature_ind = *(reinterpret_cast(data + i)); + } else if (col_type == mindrecord::ColumnInt64) { + feature_ind = *(reinterpret_cast(data + i)); + } else { + RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); + } + if (feature_ind >= 0) indices->push_back(feature_ind); + } + return Status::OK(); +} + +Status GraphLoader::WorkerEntry(int32_t worker_id) { + ShardTuple rows = shard_reader_->GetNextById(row_id_++, worker_id); + while (rows.empty() == false) { + for (const auto &tupled_row : rows) { + std::vector col_blob = std::get<0>(tupled_row); + mindrecord::json col_jsn = std::get<1>(tupled_row); + std::string attr = col_jsn["attribute"]; + if (attr == "n") { + std::shared_ptr node_ptr; + RETURN_IF_NOT_OK( + LoadNode(col_blob, col_jsn, &node_ptr, &(n_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); + n_deques_[worker_id].emplace_back(node_ptr); + } else if (attr == "e") { + std::shared_ptr edge_ptr; + RETURN_IF_NOT_OK( + LoadEdge(col_blob, col_jsn, &edge_ptr, &(e_feature_maps_[worker_id]), &default_feature_maps_[worker_id])); + e_deques_[worker_id].emplace_back(edge_ptr); + } else { + MS_LOG(WARNING) << "attribute:" << attr << " is neither edge nor node."; + } + } + rows = shard_reader_->GetNextById(row_id_++, worker_id); + } + return Status::OK(); +} + +void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, + DefaultFeatureMap *default_feature_map) { + for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { + for (auto &m : n_feature_maps_[wkr_id]) { + for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); + } + for (auto &m : e_feature_maps_[wkr_id]) { + for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); + } + for (auto &m : default_feature_maps_[wkr_id]) { + (*default_feature_map)[m.first] = m.second; + } + } + n_feature_maps_.clear(); + e_feature_maps_.clear(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h new file mode 100644 index 00000000000..7e5ccdd35f9 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/graph_loader.h @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef DATASET_ENGINE_GNN_GRAPH_LOADER_H_ +#define DATASET_ENGINE_GNN_GRAPH_LOADER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "dataset/core/data_type.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/gnn/feature.h" +#include "dataset/engine/gnn/graph.h" +#include "dataset/engine/gnn/node.h" +#include "dataset/engine/gnn/edge.h" +#include "dataset/util/status.h" +#include "mindrecord/include/shard_reader.h" +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::ShardReader; +using NodeIdMap = std::unordered_map>; +using EdgeIdMap = std::unordered_map>; +using NodeTypeMap = std::unordered_map>; +using EdgeTypeMap = std::unordered_map>; +using NodeFeatureMap = std::unordered_map>; +using EdgeFeatureMap = std::unordered_map>; +using DefaultFeatureMap = std::unordered_map>; + +// this class interfaces with the underlying storage format (mindrecord) +// it returns raw nodes and edges via GetNodesAndEdges +// it is then the responsibility of graph to construct itself based on the nodes and edges +// if needed, this class could become a base where each derived class handles a specific storage format +class GraphLoader { + public: + explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); + + // Init mindrecord and load everything into memory multi-threaded + // @return Status - the status code + Status InitAndLoad(); + + // this function will query mindrecord and construct all nodes and edges + // nodes and edges are added to map without any connection. That's because there nodes and edges are read in + // random order. src_node and dst_node in Edge are node_id only with -1 as type. + // features attached to each node and edge are expected to be filled correctly + Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, + DefaultFeatureMap *); + + private: + // + // worker thread that reads mindrecord file + // @param int32_t worker_id - id of each worker + // @return Status - the status code + Status WorkerEntry(int32_t worker_id); + + // Load a node based on 1 row of mindrecord, returns a shared_ptr + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *node - return value + // @param NodeFeatureMap *feature_map - + // @param DefaultFeatureMap *default_feature - + // @return Status - the status code + Status LoadNode(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *node, + NodeFeatureMap *feature_map, DefaultFeatureMap *default_feature); + + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *edge - return value, the edge ptr, edge is not yet connected + // @param FeatureMap *feature_map + // @param DefaultFeatureMap *default_feature - + // @return Status - the status code + Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, + EdgeFeatureMap *feature_map, DefaultFeatureMap *default_feature); + + // @param std::string key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::vector *ind - return value, list of feature index in int32_t + // @return Status - the status code + Status LoadFeatureIndex(const std::string &key, const std::vector &blob, const mindrecord::json 
&jsn, + std::vector *ind); + + // @param std::string &key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param mindrecord::json &jsn - contains raw data + // @param std::shared_ptr *tensor - return value feature tensor + // @return Status - the status code + Status LoadFeatureTensor(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, + std::shared_ptr *tensor); + + // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 + void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultFeatureMap *); + + const int32_t num_workers_; + std::atomic_int row_id_; + std::string mr_path_; + std::unique_ptr shard_reader_; + std::vector>> n_deques_; + std::vector>> e_deques_; + std::vector n_feature_maps_; + std::vector e_feature_maps_; + std::vector default_feature_maps_; + const std::vector keys_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_GRAPH_LOADER_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc new file mode 100644 index 00000000000..7465b689d5d --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/local_edge.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/gnn/local_edge.h" + +#include + +namespace mindspore { +namespace dataset { +namespace gnn { + +LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node) + : Edge(id, type, src_node, dst_node) {} + +Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = features_.find(feature_type); + if (itr != features_.end()) { + *out_feature = itr->second; + return Status::OK(); + } else { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status LocalEdge::UpdateFeature(const std::shared_ptr &feature) { + auto itr = features_.find(feature->type()); + if (itr != features_.end()) { + RETURN_STATUS_UNEXPECTED("Feature already exists"); + } else { + features_[feature->type()] = feature; + return Status::OK(); + } +} +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/dataset/engine/gnn/local_edge.h new file mode 100644 index 00000000000..a34fc003739 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/local_edge.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
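GraphLoader above splits the MindRecord scan across `num_workers_` threads: each worker claims rows through the shared atomic row counter, fills its own node/edge deques and feature-type maps, and `MergeFeatureMaps` folds the per-worker results together after loading. The snippet below is a simplified Python sketch of that load-then-merge pattern over a toy in-memory row list; it is illustrative only, not the C++ implementation.

```python
# Illustrative sketch of the load-then-merge pattern (not the C++ GraphLoader).
# Workers claim row ids from a shared counter, fill per-worker deques and
# feature-type sets, and results are merged after all workers join.
import collections
import threading

def load_rows_in_parallel(rows, num_workers=4):
    deques = [collections.deque() for _ in range(num_workers)]
    feature_maps = [set() for _ in range(num_workers)]
    counter = {"next": 0}
    lock = threading.Lock()

    def worker(worker_id):
        while True:
            with lock:                      # stands in for the atomic row_id_
                row_id = counter["next"]
                counter["next"] += 1
            if row_id >= len(rows):
                return
            row = rows[row_id]
            deques[worker_id].append(row)   # LoadNode / LoadEdge equivalent
            feature_maps[worker_id].update(row.get("features", []))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # MergeFeatureMaps equivalent: collapse per-worker maps into one view
    merged_features = set().union(*feature_maps)
    return [r for d in deques for r in d], merged_features

rows = [{"id": i, "features": [1, 2]} for i in range(10)]
print(load_rows_in_parallel(rows))
```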
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_GNN_LOCAL_EDGE_H_ +#define DATASET_ENGINE_GNN_LOCAL_EDGE_H_ + +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/engine/gnn/edge.h" +#include "dataset/engine/gnn/feature.h" +#include "dataset/engine/gnn/node.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class LocalEdge : public Edge { + public: + // Constructor + // @param EdgeIdType id - edge id + // @param EdgeType type - edge type + // @param std::shared_ptr src_node - source node + // @param std::shared_ptr dst_node - destination node + LocalEdge(EdgeIdType id, EdgeType type, std::shared_ptr src_node, std::shared_ptr dst_node); + + ~LocalEdge() = default; + + // Get the feature of a edge + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; + + // Update feature of edge + // @param std::shared_ptr feature - + // @return Status - The error code return + Status UpdateFeature(const std::shared_ptr &feature) override; + + private: + std::unordered_map> features_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_LOCAL_EDGE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.cc b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc new file mode 100644 index 00000000000..24e865dff77 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
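As the GraphLoader comments note, `GetNodesAndEdges` hands back nodes and edges that are not yet connected: an edge only carries its endpoints as node ids. Below is a minimal Python sketch, with made-up dictionaries standing in for the id maps, of the wiring step performed afterwards (resolve each edge's endpoints and register the destination as a neighbor of the source, grouped by node type, as `AddNeighbor` does). It is not the actual C++ code.

```python
# Toy wiring step (illustrative): resolve edge endpoints against the node map
# and group neighbors by the destination node's type, as AddNeighbor does.
node_map = {0: {"type": 1, "neighbors": {}},
            1: {"type": 2, "neighbors": {}}}
edge_map = {0: {"type": 0, "src": 0, "dst": 1}}

for edge in edge_map.values():
    src = node_map[edge["src"]]
    dst_id = edge["dst"]
    dst_type = node_map[dst_id]["type"]
    src["neighbors"].setdefault(dst_type, []).append(dst_id)

print(node_map[0]["neighbors"])  # {2: [1]}
```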
+ */ +#include "dataset/engine/gnn/local_node.h" + +#include +#include +#include + +#include "dataset/engine/gnn/edge.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +LocalNode::LocalNode(NodeIdType id, NodeType type) : Node(id, type) {} + +Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = features_.find(feature_type); + if (itr != features_.end()) { + *out_feature = itr->second; + return Status::OK(); + } else { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } +} + +Status LocalNode::GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector *out_neighbors) { + std::vector neighbors; + auto itr = neighbor_nodes_.find(neighbor_type); + if (itr != neighbor_nodes_.end()) { + if (samples_num == -1) { + // Return all neighbors + neighbors.resize(itr->second.size() + 1); + neighbors[0] = id_; + std::transform(itr->second.begin(), itr->second.end(), neighbors.begin() + 1, + [](const std::shared_ptr node) { return node->id(); }); + } else { + } + } else { + neighbors.push_back(id_); + MS_LOG(DEBUG) << "No neighbors. node_id:" << id_ << " neighbor_type:" << neighbor_type; + } + *out_neighbors = std::move(neighbors); + return Status::OK(); +} + +Status LocalNode::AddNeighbor(const std::shared_ptr &node) { + auto itr = neighbor_nodes_.find(node->type()); + if (itr != neighbor_nodes_.end()) { + itr->second.push_back(node); + } else { + neighbor_nodes_[node->type()] = {node}; + } + return Status::OK(); +} + +Status LocalNode::UpdateFeature(const std::shared_ptr &feature) { + auto itr = features_.find(feature->type()); + if (itr != features_.end()) { + RETURN_STATUS_UNEXPECTED("Feature already exists"); + } else { + features_[feature->type()] = feature; + return Status::OK(); + } +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/dataset/engine/gnn/local_node.h new file mode 100644 index 00000000000..25f24818e1f --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/local_node.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
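`LocalNode::GetNeighbors` with `samples_num == -1` returns the node's own id first, followed by the ids of every neighbor of the requested type; a node with no neighbors of that type yields only its own id (the sampled branch is left empty in this change). A small sketch of that layout, with made-up ids:

```python
# Layout produced for samples_num == -1 (sketch only, not the C++ code):
# own id first, then all neighbor ids of the requested type.
def all_neighbors_layout(node_id, neighbors_by_type, neighbor_type):
    return [node_id] + neighbors_by_type.get(neighbor_type, [])

print(all_neighbors_layout(103, {1: [201, 202, 203]}, 1))  # [103, 201, 202, 203]
print(all_neighbors_layout(103, {1: [201, 202, 203]}, 2))  # [103]
```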
+ */ +#ifndef DATASET_ENGINE_GNN_LOCAL_NODE_H_ +#define DATASET_ENGINE_GNN_LOCAL_NODE_H_ + +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/engine/gnn/node.h" +#include "dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class LocalNode : public Node { + public: + // Constructor + // @param NodeIdType id - node id + // @param NodeType type - node type + LocalNode(NodeIdType id, NodeType type); + + ~LocalNode() = default; + + // Get the feature of a node + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) override; + + // Get the neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector *out_neighbors) override; + + // Add neighbor of node + // @param std::shared_ptr node - + // @return Status - The error code return + Status AddNeighbor(const std::shared_ptr &node) override; + + // Update feature of node + // @param std::shared_ptr feature - + // @return Status - The error code return + Status UpdateFeature(const std::shared_ptr &feature) override; + + private: + std::unordered_map> features_; + std::unordered_map>> neighbor_nodes_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_LOCAL_NODE_H_ diff --git a/mindspore/ccsrc/dataset/engine/gnn/node.h b/mindspore/ccsrc/dataset/engine/gnn/node.h new file mode 100644 index 00000000000..8e3db51d65d --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/gnn/node.h @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
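The `Node` interface (and `Edge` likewise) keeps per-type features in a map with two simple rules, visible in the `LocalNode`/`LocalEdge` implementations above: looking up a feature type that was never attached is an error, and `UpdateFeature` refuses to attach the same feature type twice. A Python restatement of that contract, purely for illustration:

```python
# Restatement of the feature-store contract (illustrative only):
class FeatureStore:
    def __init__(self):
        self._features = {}

    def get_feature(self, feature_type):
        if feature_type not in self._features:
            raise KeyError("Invalid feature type: {}".format(feature_type))
        return self._features[feature_type]

    def update_feature(self, feature_type, value):
        if feature_type in self._features:
            raise ValueError("Feature already exists")
        self._features[feature_type] = value
```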
+ */ +#ifndef DATASET_ENGINE_GNN_NODE_H_ +#define DATASET_ENGINE_GNN_NODE_H_ + +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/engine/gnn/feature.h" + +namespace mindspore { +namespace dataset { +namespace gnn { +using NodeType = int8_t; +using NodeIdType = int32_t; + +constexpr NodeIdType kDefaultNodeId = -1; + +class Node { + public: + // Constructor + // @param NodeIdType id - node id + // @param NodeType type - node type + Node(NodeIdType id, NodeType type) : id_(id), type_(type) {} + + virtual ~Node() = default; + + // @return NodeIdType - Returned node id + NodeIdType id() const { return id_; } + + // @return NodeIdType - Returned node type + NodeType type() const { return type_; } + + // Get the feature of a node + // @param FeatureType feature_type - type of feature + // @param std::shared_ptr *out_feature - Returned feature + // @return Status - The error code return + virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr *out_feature) = 0; + + // Get the neighbors of a node + // @param NodeType neighbor_type - type of neighbor + // @param int32_t samples_num - Number of neighbors to be acquired, if -1 means all neighbors are acquired + // @param std::vector *out_neighbors - Returned neighbors id + // @return Status - The error code return + virtual Status GetNeighbors(NodeType neighbor_type, int32_t samples_num, std::vector *out_neighbors) = 0; + + // Add neighbor of node + // @param std::shared_ptr node - + // @return Status - The error code return + virtual Status AddNeighbor(const std::shared_ptr &node) = 0; + + // Update feature of node + // @param std::shared_ptr feature - + // @return Status - The error code return + virtual Status UpdateFeature(const std::shared_ptr &feature) = 0; + + protected: + NodeIdType id_; + NodeType type_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_GNN_NODE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index d1a427af276..8db7761fb85 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -114,7 +114,7 @@ class ShardReader { /// \brief aim to get columns context /// \return the columns - std::shared_ptr get_shard_column() const; + std::shared_ptr GetShardColumn() const; /// \brief get the number of shards /// \return # of shards diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index 7b3e222c9e8..fcb588fff83 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -232,7 +232,7 @@ void ShardReader::Close() { std::shared_ptr ShardReader::GetShardHeader() const { return shard_header_; } -std::shared_ptr ShardReader::get_shard_column() const { return shard_column_; } +std::shared_ptr ShardReader::GetShardColumn() const { return shard_column_; } int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); } diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 54068eb762a..93c1a6e0472 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -25,9 +25,10 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler, Sampler from .engine.serializer_deserializer import 
serialize, deserialize, show +from .engine.graphdata import GraphData __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", - "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"] + "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip", "GraphData"] diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py new file mode 100644 index 00000000000..a4c77ef3f8d --- /dev/null +++ b/mindspore/dataset/engine/graphdata.py @@ -0,0 +1,111 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +graphdata.py supports loading graph dataset for GNN network training, +and provides operations related to graph data. +""" +import numpy as np +from mindspore._c_dataengine import Graph +from mindspore._c_dataengine import Tensor + +from .validators import check_gnn_get_all_nodes, check_gnn_get_all_neighbors, check_gnn_get_node_feature + + +class GraphData: + """ + Reads th graph dataset used for GNN training from the shared file and database. + + Args: + dataset_file (str): One of file names in dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel + (default=None). + """ + + def __init__(self, dataset_file, num_parallel_workers=None): + self._dataset_file = dataset_file + if num_parallel_workers is None: + num_parallel_workers = 1 + self._graph = Graph(dataset_file, num_parallel_workers) + + @check_gnn_get_all_nodes + def get_all_nodes(self, node_type): + """ + Get all nodes in the graph. + + Args: + node_type (int): Specify the tpye of node. + + Returns: + numpy.ndarray: array of nodes. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.get_all_nodes(0) + + Raises: + TypeError: If `node_type` is not integer. + """ + return self._graph.get_nodes(node_type, -1).as_array() + + @check_gnn_get_all_neighbors + def get_all_neighbors(self, node_list, neighbor_type): + """ + Get `neighbor_type` neighbors of the nodes in `node_list`. + + Args: + node_list (list or numpy.ndarray): The given list of nodes. + neighbor_type (int): Specify the tpye of neighbor. + + Returns: + numpy.ndarray: array of nodes. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.get_all_nodes(0) + >>> neighbors = data_graph.get_all_neighbors(nodes[0], 0) + + Raises: + TypeError: If `node_list` is not list or ndarray. + TypeError: If `neighbor_type` is not integer. 
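The `GraphData` calls are typically chained: enumerate nodes of one type, expand their neighborhoods, then fetch features for the whole neighborhood matrix with `get_node_feature` (defined just below). A hedged usage sketch; the file path and the node/neighbor/feature type ids are placeholders that depend on how the MindRecord file was generated.

```python
import mindspore.dataset as ds

graph = ds.GraphData("dataset_file", 2)                 # placeholder path
nodes = graph.get_all_nodes(1)                          # ndarray of node ids
neighbors = graph.get_all_neighbors(nodes.tolist(), 2)  # one row per node
features = graph.get_node_feature(neighbors.tolist(), [1])
print(neighbors.shape, [f.shape for f in features])
```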
+ """ + return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() + + @check_gnn_get_node_feature + def get_node_feature(self, node_list, feature_types): + """ + Get `feature_types` feature of the nodes in `node_list`. + + Args: + node_list (list or numpy.ndarray): The given list of nodes. + feature_types (list or ndarray): The given list of feature types. + + Returns: + numpy.ndarray: array of features. + + Examples: + >>> import mindspore.dataset as ds + >>> data_graph = ds.GraphData('dataset_file', 2) + >>> nodes = data_graph.get_all_nodes(0) + >>> features = data_graph.get_node_feature(nodes[0], [1]) + + Raises: + TypeError: If `node_list` is not list or ndarray. + TypeError: If `feature_types` is not list or ndarray. + """ + if isinstance(node_list, list): + node_list = np.array(node_list, dtype=np.int32) + return [t.as_array() for t in self._graph.get_node_feature(Tensor(node_list), feature_types)] diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index de3987c8af3..f5005e688cd 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1032,6 +1032,7 @@ def check_textfiledataset(method): return new_method + def check_split(method): """check the input arguments of split.""" @@ -1072,3 +1073,85 @@ def check_split(method): return method(*args, **kwargs) return new_method + + +def check_list_or_ndarray(param, param_name): + if (not isinstance(param, list)) and (not hasattr(param, 'tolist')): + raise TypeError("Wrong input type for {0}, should be list, got {1}".format( + param_name, type(param))) + + +def check_gnn_get_all_nodes(method): + """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_type; required argument + check_type(param_dict.get("node_type"), 'node_type', int) + + return method(*args, **kwargs) + + return new_method + + +def check_gnn_get_all_neighbors(method): + """A wrapper that wrap a parameter checker to the GNN `get_all_neighbors` function.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_list; required argument + check_list_or_ndarray(param_dict.get("node_list"), 'node_list') + + # check neighbor_type; required argument + check_type(param_dict.get("neighbor_type"), 'neighbor_type', int) + + return method(*args, **kwargs) + + return new_method + + +def check_aligned_list(param, param_name): + """Check whether the structure of each member of the list is the same.""" + if not isinstance(param, list): + raise TypeError("Parameter {0} is not a list".format(param_name)) + membor_have_list = None + list_len = None + for membor in param: + if isinstance(membor, list): + check_aligned_list(membor, param_name) + if membor_have_list not in (None, True): + raise TypeError("The type of each member of the parameter {0} is inconsistent".format( + param_name)) + if list_len is not None and len(membor) != list_len: + raise TypeError("The size of each member of parameter {0} is inconsistent".format( + param_name)) + membor_have_list = True + list_len = len(membor) + else: + if membor_have_list not in (None, False): + raise TypeError("The type of each member of the parameter {0} is inconsistent".format( + param_name)) + membor_have_list = False + + +def check_gnn_get_node_feature(method): + """A wrapper that wrap a parameter checker to the GNN `get_node_feature` 
function.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check node_list; required argument + node_list = param_dict.get("node_list") + check_list_or_ndarray(node_list, 'node_list') + if isinstance(node_list, list): + check_aligned_list(node_list, 'node_list') + + # check feature_types; required argument + check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types') + + return method(*args, **kwargs) + + return new_method diff --git a/mindspore/mindrecord/__init__.inter b/mindspore/mindrecord/__init__.inter index ca1d50153b1..6fcabc2c5f5 100644 --- a/mindspore/mindrecord/__init__.inter +++ b/mindspore/mindrecord/__init__.inter @@ -32,6 +32,7 @@ from .mindpage import MindPage from .shardutils import SUCCESS, FAILED from .tools.cifar10_to_mr import Cifar10ToMR from .tools.cifar100_to_mr import Cifar100ToMR +from .tools.graph_map_schema import GraphMapSchema from .tools.imagenet_to_mr import ImageNetToMR from .tools.mnist_to_mr import MnistToMR diff --git a/mindspore/mindrecord/__init__.py b/mindspore/mindrecord/__init__.py index 31fb801c46b..44fc9903993 100644 --- a/mindspore/mindrecord/__init__.py +++ b/mindspore/mindrecord/__init__.py @@ -29,9 +29,10 @@ from .common.exceptions import * from .shardutils import SUCCESS, FAILED from .tools.cifar10_to_mr import Cifar10ToMR from .tools.cifar100_to_mr import Cifar100ToMR +from .tools.graph_map_schema import GraphMapSchema from .tools.imagenet_to_mr import ImageNetToMR from .tools.mnist_to_mr import MnistToMR -__all__ = ['FileWriter', 'FileReader', 'MindPage', +__all__ = ['FileWriter', 'FileReader', 'MindPage', 'GraphMapSchema', 'Cifar10ToMR', 'Cifar100ToMR', 'ImageNetToMR', 'MnistToMR', 'SUCCESS', 'FAILED'] diff --git a/mindspore/mindrecord/tools/graph_map_schema.py b/mindspore/mindrecord/tools/graph_map_schema.py new file mode 100644 index 00000000000..e131de9f650 --- /dev/null +++ b/mindspore/mindrecord/tools/graph_map_schema.py @@ -0,0 +1,145 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License") +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Graph data convert tool for MindRecord. +""" +import numpy as np + +__all__ = ['GraphMapSchema'] + + +class GraphMapSchema: + """ + Class is for transformation from graph data to MindRecord. 
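`check_aligned_list` enforces that a nested `node_list` is rectangular: at every level the members are either all scalars or all lists of the same length. The helper below restates that rule as standalone code, for illustration only (the shipped validator is `check_aligned_list` above):

```python
def is_aligned(param):
    member_is_list = [isinstance(m, list) for m in param]
    if any(member_is_list) and not all(member_is_list):
        return False                      # list and scalar members mixed
    if all(member_is_list) and param:
        if len({len(m) for m in param}) > 1:
            return False                  # member sizes differ
        return all(is_aligned(m) for m in param)
    return True

print(is_aligned([[1, 2], [3, 4]]))     # True
print(is_aligned([[1, 2], [3, 4, 5]]))  # False
print(is_aligned([[1, 2], 3]))          # False
```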
+ """ + + def __init__(self): + """ + init + """ + self.num_node_features = 0 + self.num_edge_features = 0 + self.union_schema_in_mindrecord = { + "first_id": {"type": "int64"}, + "second_id": {"type": "int64"}, + "third_id": {"type": "int64"}, + "type": {"type": "int32"}, + "attribute": {"type": "string"}, # 'n' for ndoe, 'e' for edge + "node_feature_index": {"type": "int32", "shape": [-1]}, + "edge_feature_index": {"type": "int32", "shape": [-1]} + } + + def get_schema(self): + """ + Get schema + """ + return self.union_schema_in_mindrecord + + def set_node_feature_profile(self, num_features, features_data_type, features_shape): + """ + Set node features profile + """ + if num_features != len(features_data_type) or num_features != len(features_shape): + raise ValueError("Node feature profile is not match.") + + self.num_node_features = num_features + for i in range(num_features): + k = i + 1 + field_key = 'node_feature_' + str(k) + field_value = {"type": features_data_type[i], "shape": features_shape[i]} + self.union_schema_in_mindrecord[field_key] = field_value + + def set_edge_feature_profile(self, num_features, features_data_type, features_shape): + """ + Set edge features profile + """ + if num_features != len(features_data_type) or num_features != len(features_shape): + raise ValueError("Edge feature profile is not match.") + + self.num_edge_features = num_features + for i in range(num_features): + k = i + 1 + field_key = 'edge_feature_' + str(k) + field_value = {"type": features_data_type[i], "shape": features_shape[i]} + self.union_schema_in_mindrecord[field_key] = field_value + + def transform_node(self, node): + """ + Executes transformation from node data to union format. + Args: + node(schema): node's data + Returns: + graph data with union schema + """ + node_graph = {"first_id": node["id"], "second_id": 0, "third_id": 0, "attribute": 'n', "type": node["type"], + "node_feature_index": []} + for i in range(self.num_node_features): + k = i + 1 + node_field_key = 'feature_' + str(k) + graph_field_key = 'node_feature_' + str(k) + graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"] + if node_field_key in node: + node_graph["node_feature_index"].append(k) + node_graph[graph_field_key] = np.reshape(np.array(node[node_field_key], dtype=graph_field_type), [-1]) + else: + node_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1]) + + if node_graph["node_feature_index"]: + node_graph["node_feature_index"] = np.array(node_graph["node_feature_index"], dtype="int32") + else: + node_graph["node_feature_index"] = np.array([-1], dtype="int32") + + node_graph["edge_feature_index"] = np.array([-1], dtype="int32") + for i in range(self.num_edge_features): + k = i + 1 + graph_field_key = 'edge_feature_' + str(k) + graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"] + node_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1]) + return node_graph + + def transform_edge(self, edge): + """ + Executes transformation from edge data to union format. 
+ Args: + edge(schema): edge's data + Returns: + graph data with union schema + """ + edge_graph = {"first_id": edge["id"], "second_id": edge["src_id"], "third_id": edge["dst_id"], "attribute": 'e', + "type": edge["type"], "edge_feature_index": []} + + for i in range(self.num_edge_features): + k = i + 1 + edge_field_key = 'feature_' + str(k) + graph_field_key = 'edge_feature_' + str(k) + graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"] + if edge_field_key in edge: + edge_graph["edge_feature_index"].append(k) + edge_graph[graph_field_key] = np.reshape(np.array(edge[edge_field_key], dtype=graph_field_type), [-1]) + else: + edge_graph[graph_field_key] = np.reshape(np.array([0], dtype=graph_field_type), [-1]) + + if edge_graph["edge_feature_index"]: + edge_graph["edge_feature_index"] = np.array(edge_graph["edge_feature_index"], dtype="int32") + else: + edge_graph["edge_feature_index"] = np.array([-1], dtype="int32") + + edge_graph["node_feature_index"] = np.array([-1], dtype="int32") + for i in range(self.num_node_features): + k = i + 1 + graph_field_key = 'node_feature_' + str(k) + graph_field_type = self.union_schema_in_mindrecord[graph_field_key]["type"] + edge_graph[graph_field_key] = np.array([0], dtype=graph_field_type) + return edge_graph diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index f80cc74a8b9..d8a3f0256e1 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -70,6 +70,7 @@ SET(DE_UT_SRCS concat_op_test.cc jieba_tokenizer_op_test.cc tokenizer_op_test.cc + gnn_graph_test.cc ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc new file mode 100644 index 00000000000..0aefffe7840 --- /dev/null +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -0,0 +1,93 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
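Putting `GraphMapSchema` together with `FileWriter`: every node and edge is flattened into one union row (`first_id`/`second_id`/`third_id`, an `attribute` flag of 'n' or 'e', and zero-filled placeholders for the feature columns that do not apply). The sketch below converts a toy two-node graph; the output file name, the feature profile, and the tiny in-memory graph are made-up placeholders, not part of this change.

```python
from mindspore.mindrecord import FileWriter, GraphMapSchema

graph_schema = GraphMapSchema()
# node profile mirrors the (num_features, data_types, shapes) convention
graph_schema.set_node_feature_profile(2, ["float32", "int64"], [[-1], [-1]])
graph_schema.set_edge_feature_profile(0, [], [])

writer = FileWriter("toy_graph.mindrecord", 1)          # placeholder file name
writer.add_schema(graph_schema.get_schema(), "toy gnn schema")

nodes = [{"id": 0, "type": 0, "feature_1": [0.1, 0.9], "feature_2": [3]},
         {"id": 1, "type": 0, "feature_1": [0.4, 0.6], "feature_2": [1]}]
edges = [{"id": 0, "type": 0, "src_id": 0, "dst_id": 1}]

writer.write_raw_data([graph_schema.transform_node(n) for n in nodes])
writer.write_raw_data([graph_schema.transform_edge(e) for e in edges])
writer.commit()
```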
+ */ +#include +#include + +#include "common/common.h" +#include "gtest/gtest.h" +#include "dataset/util/status.h" +#include "dataset/engine/gnn/node.h" +#include "dataset/engine/gnn/graph_loader.h" + +using namespace mindspore::dataset; +using namespace mindspore::dataset::gnn; + +class MindDataTestGNNGraph : public UT::Common { + protected: + MindDataTestGNNGraph() = default; +}; + +TEST_F(MindDataTestGNNGraph, TestGraphLoader) { + std::string path = "data/mindrecord/testGraphData/testdata"; + GraphLoader gl(path, 4); + EXPECT_TRUE(gl.InitAndLoad().IsOk()); + NodeIdMap n_id_map; + EdgeIdMap e_id_map; + NodeTypeMap n_type_map; + EdgeTypeMap e_type_map; + NodeFeatureMap n_feature_map; + EdgeFeatureMap e_feature_map; + DefaultFeatureMap default_feature_map; + EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map, + &default_feature_map) + .IsOk()); + EXPECT_EQ(n_id_map.size(), 20); + EXPECT_EQ(e_id_map.size(), 20); + EXPECT_EQ(n_type_map[2].size(), 10); + EXPECT_EQ(n_type_map[1].size(), 10); +} + +TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { + std::string path = "data/mindrecord/testGraphData/testdata"; + Graph graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + std::vector node_info; + std::vector edge_info; + s = graph.GetMetaInfo(&node_info, &edge_info); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(node_info.size() == 2); + + std::shared_ptr nodes; + s = graph.GetNodes(node_info[1].type, -1, &nodes); + EXPECT_TRUE(s.IsOk()); + std::vector node_list; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + node_list.push_back(*itr); + if (node_list.size() >= 10) { + break; + } + } + std::shared_ptr neighbors; + s = graph.GetAllNeighbors(node_list, node_info[0].type, &neighbors); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); + TensorRow features; + s = graph.GetNodeFeature(nodes, node_info[1].feature_type, &features); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(features.size() == 3); + EXPECT_TRUE(features[0]->shape().ToString() == "<10,5>"); + EXPECT_TRUE(features[0]->ToString() == + "Tensor (shape: <10,5>, Type: int32)\n" + "[[0,1,0,0,0],[1,0,0,0,1],[0,0,1,1,0],[0,0,0,0,0],[1,1,0,1,0],[0,0,0,0,1],[0,1,0,0,0],[0,0,0,1,1],[0,1,1," + "0,0],[0,1,0,1,0]]"); + EXPECT_TRUE(features[1]->shape().ToString() == "<10>"); + EXPECT_TRUE(features[1]->ToString() == + "Tensor (shape: <10>, Type: float32)\n[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]"); + EXPECT_TRUE(features[2]->shape().ToString() == "<10>"); + EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); +} diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata new file mode 100644 index 00000000000..8978131ee1f Binary files /dev/null and b/tests/ut/data/mindrecord/testGraphData/testdata differ diff --git a/tests/ut/data/mindrecord/testGraphData/testdata.db b/tests/ut/data/mindrecord/testGraphData/testdata.db new file mode 100644 index 00000000000..f846a670090 Binary files /dev/null and b/tests/ut/data/mindrecord/testGraphData/testdata.db differ diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py new file mode 100644 index 00000000000..67aa42cb259 --- /dev/null +++ b/tests/ut/python/dataset/test_graphdata.py @@ -0,0 +1,76 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in 
compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import pytest +import mindspore.dataset as ds +from mindspore import log as logger + +DATASET_FILE = "../data/mindrecord/testGraphData/testdata" + + +def test_graphdata_getfullneighbor(): + g = ds.GraphData(DATASET_FILE, 2) + nodes = g.get_all_nodes(1) + assert len(nodes) is 10 + nodes_list = nodes.tolist() + neighbor = g.get_all_neighbors(nodes_list, 2) + assert neighbor.shape == (10, 6) + row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3]) + assert row_tensor[0].shape == (10, 6) + + +def test_graphdata_getnodefeature_input_check(): + g = ds.GraphData(DATASET_FILE) + with pytest.raises(TypeError): + input_list = [1, [1, 1]] + g.get_node_feature(input_list, [1]) + + with pytest.raises(TypeError): + input_list = [[1, 1], 1] + g.get_node_feature(input_list, [1]) + + with pytest.raises(TypeError): + input_list = [[1, 1], [1, 1, 1]] + g.get_node_feature(input_list, [1]) + + with pytest.raises(TypeError): + input_list = [[1, 1, 1], [1, 1]] + g.get_node_feature(input_list, [1]) + + with pytest.raises(TypeError): + input_list = [[1, 1], [1, [1, 1]]] + g.get_node_feature(input_list, [1]) + + with pytest.raises(TypeError): + input_list = [[1, 1], [[1, 1], 1]] + g.get_node_feature(input_list, [1]) + + with pytest.raises(TypeError): + input_list = [[1, 1], [1, 1]] + g.get_node_feature(input_list, 1) + + with pytest.raises(TypeError): + input_list = [[1, 1], [1, 1]] + g.get_node_feature(input_list, ["a"]) + + with pytest.raises(TypeError): + input_list = [[1, 1], [1, 1]] + g.get_node_feature(input_list, [1, "a"]) + + +if __name__ == '__main__': + test_graphdata_getfullneighbor() + logger.info('test_graphdata_getfullneighbor Ended.\n') + test_graphdata_getnodefeature_input_check() + logger.info('test_graphdata_getnodefeature_input_check Ended.\n') diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index b2453244b18..100d2d1e16c 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -89,6 +89,8 @@ def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): num_iter = 0 for item in data_set.create_dict_iterator(): logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[data]: \ + {}------------------------".format(item["data"][:10])) logger.info("-------------- item[file_name]: \ {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
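Beyond the unit tests, the loaded neighborhoods and features can be wrapped in a `GeneratorDataset` so they flow through the usual dataset pipeline during GNN training. This is only one possible wiring, sketched under the assumption that the test MindRecord file above is used; the type ids follow the unit test and the column names are placeholders.

```python
import numpy as np
import mindspore.dataset as ds

DATASET_FILE = "../data/mindrecord/testGraphData/testdata"

def gnn_samples():
    # type ids (1, 2, [2]) mirror the unit test above
    graph = ds.GraphData(DATASET_FILE, 2)
    nodes = graph.get_all_nodes(1)
    neighbors = graph.get_all_neighbors(nodes.tolist(), 2)
    features = graph.get_node_feature(neighbors.tolist(), [2])
    for row, feat in zip(neighbors, features[0]):
        yield np.array(row, dtype=np.int32), np.array(feat, dtype=np.int32)

dataset = ds.GeneratorDataset(gnn_samples, ["neighbors", "feature"])
for data in dataset.create_dict_iterator():
    pass  # data["neighbors"] and data["feature"] feed the network
```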