From 92039382950022b4e3cd899365dc8d73998566dd Mon Sep 17 00:00:00 2001 From: ms_yan Date: Sat, 16 Jul 2022 10:58:06 +0800 Subject: [PATCH] add chinese api for Graph API --- .../mindspore.dataset.ArgoverseDataset.rst | 31 ++ .../dataset/mindspore.dataset.Graph.rst | 280 ++++++++++++++++++ ...mindspore.dataset.InMemoryGraphDataset.rst | 46 +++ docs/api/api_python/mindspore.dataset.rst | 4 +- docs/api/api_python_en/mindspore.dataset.rst | 3 + .../mindspore/dataset/engine/graphdata.py | 113 ++++--- 6 files changed, 441 insertions(+), 36 deletions(-) create mode 100644 docs/api/api_python/dataset/mindspore.dataset.ArgoverseDataset.rst create mode 100644 docs/api/api_python/dataset/mindspore.dataset.Graph.rst create mode 100644 docs/api/api_python/dataset/mindspore.dataset.InMemoryGraphDataset.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.ArgoverseDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.ArgoverseDataset.rst new file mode 100644 index 00000000000..aaa2b8c1bcb --- /dev/null +++ b/docs/api/api_python/dataset/mindspore.dataset.ArgoverseDataset.rst @@ -0,0 +1,31 @@ +mindspore.dataset.ArgoverseDataset +================================== + +.. py:class:: mindspore.dataset.ArgoverseDataset(data_dir, column_names="graph", shuffle=None, num_parallel_workers=1, python_multiprocessing=True, perf_mode=True) + + 加载argoverse数据集并进行图(Graph)初始化。 + + Argoverse数据集是自动驾驶领域的公共数据集,当前实现的 `ArgoverseDataset` 主要用于加载argoverse数据集中运动预测(Motion Forecasting)场景的数据集,具体信息可访问官网了解: + https://www.argoverse.org/av1.html#download-link + + 参数: + - **data_dir** (str) - 加载数据集的目录,这里包含原始格式的数据,并将在 `process` 方法中被加载。 + - **column_names** (Union[str, list[str]],可选) - dataset包含的单个列名或多个列名组成的列表,默认值:'Graph'。当实现类似 `__getitem__` 等方法时,列名的数量应该等于该方法中返回数据的条数。 + - **num_parallel_workers** (int,可选) - 指定读取数据的工作进程数/线程数(由参数 `python_multiprocessing` 决定当前为多进程模式或多线程模式),默认值:1。 + - **shuffle** (bool,可选) - 是否混洗数据集。当实现的Dataset带有可随机访问属性( `__getitem__` )时,才可以指定该参数。默认值:None。 + - **python_multiprocessing** (bool,可选) - 启用Python多进程模式加速运算,默认值:True。当传入 `source` 的Python对象的计算量很大时,开启此选项可能会有较好效果。 + - **perf_mode** (bool,可选) - 遍历创建的dataset对象时获得更高性能的模式(在此过程中将调用 `__getitem__` 方法)。默认值:True,将Graph的所有数据(如边的索引、节点特征和图的特征)都作为图特征进行存储。 + + .. include:: mindspore.dataset.Dataset.add_sampler.rst + + .. include:: mindspore.dataset.Dataset.rst + + .. include:: mindspore.dataset.Dataset.b.rst + + .. include:: mindspore.dataset.Dataset.c.rst + + .. include:: mindspore.dataset.Dataset.d.rst + + .. include:: mindspore.dataset.Dataset.use_sampler.rst + + .. include:: mindspore.dataset.Dataset.zip.rst diff --git a/docs/api/api_python/dataset/mindspore.dataset.Graph.rst b/docs/api/api_python/dataset/mindspore.dataset.Graph.rst new file mode 100644 index 00000000000..bbba9d98e71 --- /dev/null +++ b/docs/api/api_python/dataset/mindspore.dataset.Graph.rst @@ -0,0 +1,280 @@ +mindspore.dataset.Graph +======================= + +.. py:class:: mindspore.dataset.Graph(edges, node_feat=None, edge_feat=None, graph_feat=None, node_type=None, edge_type=None, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True) + + 主要用于存储图的结构信息和图特征属性,并提供图采样等能力。 + + 该接口支持输入表示节点、边及其特征的NumPy数组,来进行图初始化。如果 `working_mode` 是默认的 `local` 模式,则不需要指定 `working_mode`、`hostname` 、 `port` 、 `num_client` 、 `auto_shutdown` 等输入参数。 + + 参数: + - **edges**(Union[list, numpy.ndarray]): 以COO格式表示的边,shape为 [2, num_edges]。 + - **node_feat**(dict, 可选): 节点的特征,输入数据格式应该是dict,其中key表示特征的类型,用字符串表示,比如'weight'等;value应该是shape为 [num_nodes, num_node_features] 的NumPy数组。 + - **edge_feat**(dict, 可选): 边的特征,输入数据格式应该是dict,其中key表示特征的类型,用字符串表示,比如'weight'等;value应该是shape为 [num_edges, num_edge_features] 的NumPy数组。 + - **graph_feat**(dict, 可选):附加特征,不能分配给 `node_feat` 或者 `edge_feat` ,输入数据格式应该是dict,key是特征的类型,用字符串表示; value应该是NumPy数组,其shape可以不受限制。 + - **node_type**(Union[list, numpy.ndarray], 可选): 节点的类型,每个元素都是字符串,表示每个节点的类型。如果未提供,则每个节点的默认类型为“0”。 + - **edge_type**(Union[list, numpy.ndarray], 可选): 边的类型,每个元素都是字符串,表示每条边的类型。如果未提供,则每条边的默认类型为“0”。 + - **num_parallel_workers** (int, 可选) - 读取数据的工作线程数,默认值:None,使用mindspore.dataset.config中配置的线程数。 + - **working_mode** (str, 可选) - 设置工作模式,目前支持'local'/'client'/'server',默认值:'local'。 + + - **local**:用于非分布式训练场景。 + - **client**:用于分布式训练场景。客户端不加载数据,而是从服务器获取数据。 + - **server**:用于分布式训练场景。服务器加载数据并可供客户端使用。 + + - **hostname** (str, 可选) - 图数据集服务器的主机名。该参数仅在工作模式设置为 'client' 或 'server' 时有效,默认值:'127.0.0.1'。 + - **port** (int, 可选) - 图数据服务器的端口,取值范围为1024-65535。此参数仅当工作模式设置为 'client' 或 'server' 时有效,默认值:50051。 + - **num_client** (int, 可选) - 期望连接到服务器的最大客户端数。服务器将根据该参数分配资源。该参数仅在工作模式设置为 'server' 时有效,默认值:1。 + - **auto_shutdown** (bool, 可选) - 当工作模式设置为 'server' 时有效。当连接的客户端数量达到 `num_client` ,且没有客户端正在连接时,服务器将自动退出,默认值:True。 + + 异常: + - **ValueError** - `num_parallel_workers` 参数超过系统最大线程数。 + - **ValueError** - `working_mode` 参数取值不为'local', 'client' 或 'server'。 + - **TypeError** - `hostname` 参数类型错误。 + - **ValueError** - `port` 参数不在范围[1024, 65535]内。 + - **ValueError** - `num_client` 参数不在范围[1, 255]内。 + + .. py:method:: get_all_edges(edge_type) + + 获取图的所有边。 + + 参数: + - **edge_type** (str) - 指定边的类型,Graph初始化未指定`edge_type`时,默认值为'0'。详见 `加载图数据集 `_ 。 + + 返回: + numpy.ndarray,包含边的数组。 + + 异常: + - **TypeError** - 参数 `edge_type` 的类型不是string类型。 + + .. py:method:: get_all_neighbors(node_list, neighbor_type, output_format=OutputFormat.NORMAL) + + 获取 `node_list` 所有节点的相邻节点,以 `neighbor_type` 类型返回。格式的定义参见以下示例:1表示两个节点之间连接,0表示不连接。 + + .. list-table:: 邻接矩阵 + :widths: 20 20 20 20 20 + :header-rows: 1 + + * - + - 0 + - 1 + - 2 + - 3 + * - 0 + - 0 + - 1 + - 0 + - 0 + * - 1 + - 0 + - 0 + - 1 + - 0 + * - 2 + - 1 + - 0 + - 0 + - 1 + * - 3 + - 1 + - 0 + - 0 + - 0 + + .. list-table:: 普通格式 + :widths: 20 20 20 20 20 + :header-rows: 1 + + * - src + - 0 + - 1 + - 2 + - 3 + * - dst_0 + - 1 + - 2 + - 0 + - 1 + * - dst_1 + - -1 + - -1 + - 3 + - -1 + + .. list-table:: COO格式 + :widths: 20 20 20 20 20 20 + :header-rows: 1 + + * - src + - 0 + - 1 + - 2 + - 2 + - 3 + * - dst + - 1 + - 2 + - 0 + - 3 + - 1 + + .. list-table:: CSR格式 + :widths: 40 20 20 20 20 20 + :header-rows: 1 + + * - offsetTable + - 0 + - 1 + - 2 + - 4 + - + * - dstTable + - 1 + - 2 + - 0 + - 3 + - 1 + + 参数: + - **node_list** (Union[list, numpy.ndarray]) - 给定的节点列表。 + - **neighbor_type** (str) - 指定相邻节点的类型。 + - **output_format** (OutputFormat, 可选) - 输出存储格式,默认值:mindspore.dataset.OutputFormat.NORMAL,取值范围:[OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]。 + + 返回: + 对于普通格式或COO格式,将返回numpy.ndarray类型的数组表示相邻节点。如果指定了CSR格式,将返回两个numpy.ndarray数组,第一个表示偏移表,第二个表示相邻节点。 + + 异常: + - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `neighbor_type` 的类型不是string类型。 + + .. py:method:: get_all_nodes(node_type) + + 获取图中的所有节点。 + + 参数: + - **node_type** (str) - 指定节点的类型。Graph初始化未指定`edge_type`时,默认值为'0'。详见 `加载图数据集 `_ 。 + + 返回: + numpy.ndarray,包含节点的数组。 + + 异常: + - **TypeError** - 参数 `node_type` 的类型不是string类型。 + + .. py:method:: get_edge_feature(edge_list, feature_types) + + 获取 `edge_list` 列表中边的特征,以 `feature_types` 类型返回。 + + 参数: + - **edge_list** (Union[list, numpy.ndarray]) - 包含边的列表。 + - **feature_types** (Union[list, numpy.ndarray]) - 包含给定特征类型的列表,列表中每个元素是string类型。 + + 返回: + numpy.ndarray,包含特征的数组。 + + 异常: + - **TypeError** - 参数 `edge_list` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。 + + .. py:method:: get_edges_from_nodes(node_list) + + 从节点获取边。 + + 参数: + - **node_list** (Union[list[tuple], numpy.ndarray]) - 含一个或多个图节点ID对的列表。 + + 返回: + numpy.ndarray,含一个或多个边ID的数组。 + + 异常: + - **TypeError** - 参数 `edge_list` 的类型不为列表或numpy.ndarray。 + + .. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type) + + 获取 `node_list` 列表中节所有点的负样本相邻节点,以 `neg_neighbor_type` 类型返回。 + + 参数: + - **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。 + - **neg_neighbor_num** (int) - 采样的相邻节点数量。 + - **neg_neighbor_type** (str) - 指定负样本相邻节点的类型。 + + 返回: + numpy.ndarray,包含相邻节点的数组。 + + 异常: + - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `neg_neighbor_num` 的类型不为整型。 + - **TypeError** - 参数 `neg_neighbor_type` 的类型不是string类型。 + + .. py:method:: get_node_feature(node_list, feature_types) + + 获取 `node_list` 中节点的特征,以 `feature_types` 类型返回。 + + 参数: + - **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。 + - **feature_types** (Union[list, numpy.ndarray]) - 指定特征的类型,类型列表中每个元素应该是string类型。 + + 返回: + numpy.ndarray,包含特征的数组。 + + 异常: + - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。 + + + .. py:method:: get_nodes_from_edges(edge_list) + + 从图中的边获取节点。 + + 参数: + - **edge_list** (Union[list, numpy.ndarray]) - 包含边的列表。 + + 返回: + numpy.ndarray,包含节点的数组。 + + 异常: + - **TypeError** - 参数 `edge_list` 不为列表或ndarray。 + + .. py:method:: get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM) + + 获取已采样相邻节点信息。此API支持多跳相邻节点采样。即将上一次采样结果作为下一跳采样的输入,最多允许6跳。采样结果平铺成列表,格式为[input node, 1-hop sampling result, 2-hop samling result ...] + + 参数: + - **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。 + - **neighbor_nums** (Union[list, numpy.ndarray]) - 每跳采样的相邻节点数。 + - **neighbor_types** (Union[list, numpy.ndarray]) - 每跳采样的相邻节点类型。 + - **strategy** (SamplingStrategy, 可选) - 采样策略,默认值:mindspore.dataset.SamplingStrategy.RANDOM。取值范围:[SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]。 + + - **SamplingStrategy.RANDOM**:随机抽样,带放回采样。 + - **SamplingStrategy.EDGE_WEIGHT**:以边缘权重为概率进行采样。 + + 返回: + numpy.ndarray,包含相邻节点的数组。 + + 异常: + - **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `neighbor_nums` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `neighbor_types` 的类型不为列表或numpy.ndarray。 + + .. py:method:: graph_info() + + 获取图的元信息,包括节点数、节点类型、节点特征信息、边数、边类型、边特征信息。 + + 返回: + dict,图的元信息。键为 `node_num` 、 `node_type` 、 `node_feature_type` 、 `edge_num` 、 `edge_type` 、`edge_feature_type` 和 `graph_feature_type`。 + + .. py:method:: random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1) + + 在节点中的随机游走。 + + 参数: + - **target_nodes** (list[int]) - 随机游走中的起始节点列表。 + - **meta_path** (list[int]) - 每个步长的节点类型。 + - **step_home_param** (float, 可选) - 返回 `node2vec算法 `_ 中的超参,默认值:1.0。 + - **step_away_param** (float, 可选) - `node2vec算法 `_ 中的in和out超参,默认值:1.0。 + - **default_node** (int, 可选) - 如果找不到更多相邻节点,则为默认节点,默认值:-1,表示不给定节点。 + + 返回: + numpy.ndarray,包含节点的数组。 + + 异常: + - **TypeError** - 参数 `target_nodes` 的类型不为列表或numpy.ndarray。 + - **TypeError** - 参数 `meta_path` 的类型不为列表或numpy.ndarray。 diff --git a/docs/api/api_python/dataset/mindspore.dataset.InMemoryGraphDataset.rst b/docs/api/api_python/dataset/mindspore.dataset.InMemoryGraphDataset.rst new file mode 100644 index 00000000000..362f0f6a764 --- /dev/null +++ b/docs/api/api_python/dataset/mindspore.dataset.InMemoryGraphDataset.rst @@ -0,0 +1,46 @@ +mindspore.dataset.InMemoryGraphDataset +====================================== + +.. py:class:: mindspore.dataset.InMemoryGraphDataset(data_dir, save_dir="./processed", column_names="graph", num_samples=None, num_parallel_workers=1, shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6) + + 用于将图数据加载到内存中的Dataset基类。 + + 建议通过继承这个基类来实现相应的Dataset,并实现对应的方法,比如'process'、'save'和'load'。 + + 参数: + - **data_dir** (str) - 加载数据集的目录,这里包含原始格式的数据,并将在 `process` 方法中被加载。 + - **save_dir** (str) - 保存处理后得到的数据集的相对目录,该目录位于 `data_dir` 下面。 + - **column_names** (Union[str, list[str]],可选) - dataset包含的单个列名或多个列名组成的列表,默认值:'Graph'。当实现类似 `__getitem__` 等方法时,列名的数量应该等于该方法中返回数据的条数。 + - **num_samples** (int,可选) - 指定从数据集中读取的样本数,默认值:None,读取全部样本。 + - **num_parallel_workers** (int,可选) - 指定读取数据的工作进程数/线程数(由参数 `python_multiprocessing` 决定当前为多进程模式或多线程模式),默认值:1。 + - **shuffle** (bool,可选) - 是否混洗数据集。当实现的Dataset带有可随机访问属性( `__getitem__` )时,才可以指定该参数。默认值:None。 + - **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数,默认值:None。指定此参数后, `num_samples` 表示每个分片的最大样本数。 + - **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号,默认值:None。只有当指定了 `num_shards` 时才能指定此参数。 + - **python_multiprocessing** (bool,可选) - 启用Python多进程模式加速运算,默认值:True。当传入 `source` 的Python对象的计算量很大时,开启此选项可能会有较好效果。 + - **max_rowsize** (int,可选) - 指定在多进程之间复制数据时,共享内存分配的最大空间,默认值:6,单位为MB。仅当参数 `python_multiprocessing` 设为True时,此参数才会生效。 + + .. py:method:: process() + + 与原始数据集相关的处理方法,建议在自定义的Dataset中重写此方法。 + + .. py:method:: save() + + 将经过 `process` 函数处理后的数据以 numpy.npz 格式保存到磁盘中,也可以在自己实现的Dataset类中自己实现这个方法。 + + .. py:method:: load() + + 从给定(处理好的)路径加载数据,也可以在自己实现的Dataset类中实现这个方法。 + + .. include:: mindspore.dataset.Dataset.add_sampler.rst + + .. include:: mindspore.dataset.Dataset.rst + + .. include:: mindspore.dataset.Dataset.b.rst + + .. include:: mindspore.dataset.Dataset.c.rst + + .. include:: mindspore.dataset.Dataset.d.rst + + .. include:: mindspore.dataset.Dataset.use_sampler.rst + + .. include:: mindspore.dataset.Dataset.zip.rst diff --git a/docs/api/api_python/mindspore.dataset.rst b/docs/api/api_python/mindspore.dataset.rst index 1ac2038f981..7411754e69b 100644 --- a/docs/api/api_python/mindspore.dataset.rst +++ b/docs/api/api_python/mindspore.dataset.rst @@ -131,8 +131,10 @@ mindspore.dataset .. mscnautosummary:: :toctree: dataset + mindspore.dataset.ArgoverseDataset + mindspore.dataset.Graph mindspore.dataset.GraphData - + mindspore.dataset.InMemoryGraphDataset 采样器 ------- diff --git a/docs/api/api_python_en/mindspore.dataset.rst b/docs/api/api_python_en/mindspore.dataset.rst index f5a929afa93..565875b221a 100644 --- a/docs/api/api_python_en/mindspore.dataset.rst +++ b/docs/api/api_python_en/mindspore.dataset.rst @@ -113,7 +113,10 @@ Graph :nosignatures: :template: classtemplate_inherited.rst + mindspore.dataset.ArgoverseDataset + mindspore.dataset.Graph mindspore.dataset.GraphData + mindspore.dataset.InMemoryGraphDataset Sampler -------- diff --git a/mindspore/python/mindspore/dataset/engine/graphdata.py b/mindspore/python/mindspore/dataset/engine/graphdata.py index fb35a535a3f..0f4518c0d6b 100644 --- a/mindspore/python/mindspore/dataset/engine/graphdata.py +++ b/mindspore/python/mindspore/dataset/engine/graphdata.py @@ -22,7 +22,6 @@ import random import time from enum import IntEnum import numpy as np -import pandas as pd from mindspore._c_dataengine import GraphDataClient from mindspore._c_dataengine import GraphDataServer from mindspore._c_dataengine import Tensor @@ -138,22 +137,18 @@ class GraphData: if working_mode in ['local', 'client']: self._graph_data = GraphDataClient(self.data_format, dataset_file, num_parallel_workers, working_mode, hostname, port) - atexit.register(self.stop) + atexit.register(self._stop) if working_mode == 'server': self._graph_data = GraphDataServer( self.data_format, dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) - atexit.register(self.stop) + atexit.register(self._stop) try: while self._graph_data.is_stopped() is not True: time.sleep(1) except KeyboardInterrupt: raise Exception("Graph data server receives KeyboardInterrupt.") - def stop(self): - """Stop GraphDataClient or GraphDataServer.""" - self._graph_data.stop() - @check_gnn_get_all_nodes def get_all_nodes(self, node_type): """ @@ -530,23 +525,30 @@ class GraphData: return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param, default_node).as_array() + def _stop(self): + """Stop GraphDataClient or GraphDataServer.""" + self._graph_data.stop() + class Graph(GraphData): """ - A graph object for storing Graph structure and feature data. + A graph object for storing Graph structure and feature data, and provide capabilities such as graph sampling. - This class supports init graph With input numpy array data, which represent edge, node and its features. + This class supports init graph With input numpy array data, which represent node, edge and its features. If working mode is `local`, there is no need to specify input arguments like `working_mode`, `hostname`, `port`, `num_client`, `auto_shutdown`. Args: edges(Union[list, numpy.ndarray]): edges of graph in COO format with shape [2, num_edges]. - node_feat(dict, optional): feature of nodes, key is feature type, value should be numpy.array with shape - [num_nodes, num_node_features], feature type should be string, like 'weight' etc. - edge_feat(dict, optional): feature of edges, key is feature type, value should be numpy.array with shape - [num_edges, num_edge_features], feature type should be string, like 'weight' etc. - graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, key is - feature type, value should be numpy.array. + node_feat(dict, optional): feature of nodes, input data format should be dict, key is feature type, which is + represented with string like 'weight' etc, value should be numpy.array with shape + [num_nodes, num_node_features]. + edge_feat(dict, optional): feature of edges, input data format should be dict, key is feature type, which is + represented with string like 'weight' etc, value should be numpy.array with shape + [num_edges, num_edge_features]. + graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, input data + format should be dict, key is feature type, which is represented with string, value should be numpy.array, + its shape is not restricted. node_type(Union[list, numpy.ndarray], optional): type of nodes, each element should be string which represent type of corresponding node. If not provided, default type for each node is '0'. edge_type(Union[list, numpy.ndarray], optional): type of edges, each element should be string which represent @@ -630,23 +632,19 @@ class Graph(GraphData): self._graph_data = GraphDataClient(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat, node_type, edge_type, num_parallel_workers, working_mode, hostname, port) - atexit.register(self.stop) + atexit.register(self._stop) if working_mode == 'server': self._graph_data = GraphDataServer(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat, node_type, edge_type, num_parallel_workers, hostname, port, num_client, auto_shutdown) - atexit.register(self.stop) + atexit.register(self._stop) try: while self._graph_data.is_stopped() is not True: time.sleep(1) except KeyboardInterrupt: raise Exception("Graph data server receives KeyboardInterrupt.") - def stop(self): - """Stop GraphDataClient or GraphDataServer.""" - self._graph_data.stop() - @check_gnn_get_all_nodes def get_all_nodes(self, node_type): """ @@ -679,7 +677,8 @@ class Graph(GraphData): Get all edges in the graph. Args: - edge_type (int): Specify the type of edge. + edge_type (str): Specify the type of edge, default edge_type is '0' when init graph without specify + edge_type. Returns: numpy.ndarray, array of edges. @@ -1037,7 +1036,7 @@ class Graph(GraphData): Returns: dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num, - node_feature_type and edge_feature_type. + node_feature_type, edge_feature_type and graph_feature_type. """ if self._working_mode == 'server': raise Exception("This method is not supported when working mode is server.") @@ -1223,11 +1222,33 @@ class _UsersDatasetTemplate: class InMemoryGraphDataset(GeneratorDataset): """ - The basic Dataset for loading graph into memory. - Recommended to inherit this class, and implement your own method like 'process', 'save' and 'load'. + Basic Dataset for loading graph into memory. + + Recommended to Implement your own dataset with inheriting this class, and implement your own method like 'process', + 'save' and 'load'. + + Args: + data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in + `process` method. + save_dir (str): relative directory for saving processed dataset, this directory is under `data_dir`. + column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset, + num of column name should be equal to num of item in return data when implement method like `__getitem__`. + num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples). + num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). + shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. + (default=None, expected order behavior shown in the table). + num_shards (int, optional): Number of shards that the dataset will be divided into (default=None). + Random accessible input is required. When this argument is specified, `num_samples` reflects the max + sample number of per shard. + shard_id (int, optional): The shard ID within `num_shards` (default=None). This argument must be specified only + when num_shards is also specified. Random accessible input is required. + python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This + option could be beneficial if the Python operation is computational heavy (default=True). + max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory allocation to copy + data between processes. This is only used if python_multiprocessing is set to True (default 6 MB). """ - def __init__(self, data_dir, column_names="graph", save_dir="./processed", num_parallel_workers=1, + def __init__(self, data_dir, save_dir="./processed", column_names="graph", num_samples=None, num_parallel_workers=1, shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6): self.graphs = [] self.data_dir = data_dir @@ -1244,25 +1265,25 @@ class InMemoryGraphDataset(GeneratorDataset): setattr(source, k, v) for k, v in self.__class__.__dict__.items(): setattr(source.__class__, k, getattr(self.__class__, k)) - super().__init__(source, column_names=column_names, num_parallel_workers=num_parallel_workers, shuffle=shuffle, - num_shards=num_shards, shard_id=shard_id, python_multiprocessing=python_multiprocessing, - max_rowsize=max_rowsize) + super().__init__(source, column_names=column_names, num_samples=num_samples, + num_parallel_workers=num_parallel_workers, shuffle=shuffle, num_shards=num_shards, + shard_id=shard_id, python_multiprocessing=python_multiprocessing, max_rowsize=max_rowsize) def process(self): """ - Override this method in your our dataset class. + Process method based on origin dataset, override this method in your our dataset class. """ raise NotImplementedError("'process' method should be implemented in your own logic.") def save(self): """ - Override this method in your our dataset class. + Save processed data into disk in numpy.npz format, you can also override this method in your dataset class. """ save_graphs(self.processed_path, self.graphs) def load(self): """ - Override this method in your our dataset class. + Load data from given(processed) path, you can also override this method in your dataset class. """ self.graphs = load_graphs(self.processed_path, num_parallel_workers=1) @@ -1280,14 +1301,32 @@ class InMemoryGraphDataset(GeneratorDataset): class ArgoverseDataset(InMemoryGraphDataset): """ Load argoverse dataset and create graph. + + Here argoverse dataset is public dataset for autonomous driving, current implement `ArgoverseDataset` is mainly for + loading Motion Forecasting Dataset in argoverse dataset, recommend to visit official website for more detail: + https://www.argoverse.org/av1.html#download-link. + + Args: + data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in + `process` method. + column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset, + num of column name should be equal to num of item in return data when implement method like `__getitem__`. + num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). + shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. + (default=None, expected order behavior shown in the table). + python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This + option could be beneficial if the Python operation is computational heavy (default=True). + perf_mode(bool, optional): mode for obtaining higher performance when iterate created dataset(will call + `__getitem__` method in this process). Default True, will save all the data in graph + (like edge index, node feature and graph feature) into graph feature. """ - def __init__(self, data_dir, column_names="graph", shuffle=None, num_parallel_workers=1, + def __init__(self, data_dir, column_names="graph", num_parallel_workers=1, shuffle=None, python_multiprocessing=True, perf_mode=True): # For high performance, here we store edge_index into graph_feature directly self.perf_mode = perf_mode - super().__init__(data_dir, column_names, shuffle=shuffle, num_parallel_workers=num_parallel_workers, - python_multiprocessing=python_multiprocessing) + super().__init__(data_dir=data_dir, column_names=column_names, shuffle=shuffle, + num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing) def __getitem__(self, index): graph = self.graphs[index] @@ -1312,6 +1351,10 @@ class ArgoverseDataset(InMemoryGraphDataset): """ process method mainly refers to: https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py """ + try: + import pandas as pd + except ImportError: + raise ImportError("Import pandas failed, recommend to install pandas with pip.") def get_edge_full_connection(node_num, start_index=0): """