!38227 add chinese api for Graph API

Merge pull request !38227 from ms_yan/CN_api_graph
i-robot 2022-07-18 12:47:59 +00:00 committed by Gitee
commit eb4030b089
6 changed files with 441 additions and 36 deletions


@@ -0,0 +1,31 @@
mindspore.dataset.ArgoverseDataset
==================================
.. py:class:: mindspore.dataset.ArgoverseDataset(data_dir, column_names="graph", shuffle=None, num_parallel_workers=1, python_multiprocessing=True, perf_mode=True)
Load the Argoverse dataset and initialize a Graph from it.

The Argoverse dataset is a public dataset for autonomous driving. The current `ArgoverseDataset` implementation is mainly for loading the Motion Forecasting part of the Argoverse dataset; see the official website for details:
https://www.argoverse.org/av1.html#download-link

Parameters:
- **data_dir** (str) - Directory for loading the dataset; it contains data in the original format, which will be loaded in the `process` method.
- **column_names** (Union[str, list[str]], optional) - Single column name of the dataset, or a list of multiple column names. Default: 'graph'. When implementing a method such as `__getitem__`, the number of column names should equal the number of items returned by that method.
- **num_parallel_workers** (int, optional) - Number of worker processes/threads used to read the data (whether multiprocess or multithread mode is used is determined by `python_multiprocessing`). Default: 1.
- **shuffle** (bool, optional) - Whether to shuffle the dataset. This parameter can only be specified when the implemented Dataset is randomly accessible via `__getitem__`. Default: None.
- **python_multiprocessing** (bool, optional) - Enable Python multiprocessing to accelerate computation. Default: True. This option can be beneficial when the computation on the Python object passed as `source` is heavy.
- **perf_mode** (bool, optional) - Mode for obtaining higher performance when iterating over the created dataset (the `__getitem__` method is called during iteration). Default: True, which stores all data of the Graph (such as edge indices, node features and graph features) as graph features.
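A minimal usage sketch follows (the data directory is a hypothetical path containing pre-downloaded raw Motion Forecasting data):

.. code-block:: python

    from mindspore.dataset import ArgoverseDataset

    # Assumption: ./argoverse_raw holds the raw Motion Forecasting files.
    dataset = ArgoverseDataset(data_dir="./argoverse_raw", column_names="graph")
    for item in dataset.create_dict_iterator(output_numpy=True):
        print(item["graph"])  # one graph sample per iteration
        break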
.. include:: mindspore.dataset.Dataset.add_sampler.rst
.. include:: mindspore.dataset.Dataset.rst
.. include:: mindspore.dataset.Dataset.b.rst
.. include:: mindspore.dataset.Dataset.c.rst
.. include:: mindspore.dataset.Dataset.d.rst
.. include:: mindspore.dataset.Dataset.use_sampler.rst
.. include:: mindspore.dataset.Dataset.zip.rst


@@ -0,0 +1,280 @@
mindspore.dataset.Graph
=======================
.. py:class:: mindspore.dataset.Graph(edges, node_feat=None, edge_feat=None, graph_feat=None, node_type=None, edge_type=None, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)
Stores the structural information and feature attributes of a graph, and provides capabilities such as graph sampling.

This interface supports initializing the graph from NumPy arrays that represent the nodes, edges and their features. If `working_mode` is the default `local`, there is no need to specify the input arguments `working_mode`, `hostname`, `port`, `num_client` and `auto_shutdown`.

Parameters:
- **edges** (Union[list, numpy.ndarray]) - Edges of the graph in COO format, with shape [2, num_edges].
- **node_feat** (dict, optional) - Features of the nodes. The input data format should be a dict, where each key is the feature type, represented as a string such as 'weight', and each value should be a NumPy array with shape [num_nodes, num_node_features].
- **edge_feat** (dict, optional) - Features of the edges. The input data format should be a dict, where each key is the feature type, represented as a string such as 'weight', and each value should be a NumPy array with shape [num_edges, num_edge_features].
- **graph_feat** (dict, optional) - Additional features that cannot be assigned to `node_feat` or `edge_feat`. The input data format should be a dict, where each key is the feature type, represented as a string, and each value should be a NumPy array whose shape is not restricted.
- **node_type** (Union[list, numpy.ndarray], optional) - Types of the nodes; each element is a string representing the type of the corresponding node. If not provided, the default type of each node is '0'.
- **edge_type** (Union[list, numpy.ndarray], optional) - Types of the edges; each element is a string representing the type of the corresponding edge. If not provided, the default type of each edge is '0'.
- **num_parallel_workers** (int, optional) - Number of worker threads used to read the data. Default: None, which uses the number of threads configured in mindspore.dataset.config.
- **working_mode** (str, optional) - Working mode; currently supports 'local'/'client'/'server'. Default: 'local'.

  - **local**: used for non-distributed training scenarios.
  - **client**: used for distributed training scenarios. The client does not load data, but obtains it from the server.
  - **server**: used for distributed training scenarios. The server loads the data and makes it available to clients.

- **hostname** (str, optional) - Hostname of the graph dataset server. This parameter is only valid when the working mode is 'client' or 'server'. Default: '127.0.0.1'.
- **port** (int, optional) - Port of the graph data server, in the range 1024-65535. This parameter is only valid when the working mode is 'client' or 'server'. Default: 50051.
- **num_client** (int, optional) - Maximum number of clients expected to connect to the server. The server allocates resources according to this parameter. This parameter is only valid when the working mode is 'server'. Default: 1.
- **auto_shutdown** (bool, optional) - Valid when the working mode is 'server'. The server exits automatically once the number of connected clients reaches `num_client` and no client is still connected. Default: True.

Raises:
- **ValueError** - `num_parallel_workers` exceeds the maximum number of threads of the system.
- **ValueError** - `working_mode` is not 'local', 'client' or 'server'.
- **TypeError** - `hostname` has an invalid type.
- **ValueError** - `port` is not in the range [1024, 65535].
- **ValueError** - `num_client` is not in the range [1, 255].
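A minimal construction sketch (the feature key 'feature' and all values are illustrative; the 5-edge graph below matches the adjacency matrix used in `get_all_neighbors` further down):

.. code-block:: python

    import numpy as np
    from mindspore.dataset import Graph

    # COO edges with shape [2, num_edges]: 0->1, 1->2, 2->0, 2->3, 3->1.
    edges = np.array([[0, 1, 2, 2, 3],
                      [1, 2, 0, 3, 1]], dtype=np.int32)
    # One feature type per node, shaped [num_nodes, num_node_features].
    node_feat = {"feature": np.random.rand(4, 2).astype(np.float32)}
    g = Graph(edges, node_feat=node_feat)

    print(g.get_all_nodes(node_type='0'))  # IDs of the 4 nodes
    print(g.get_all_edges(edge_type='0'))  # IDs of the 5 edges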
.. py:method:: get_all_edges(edge_type)
Get all edges of the graph.

Parameters:
- **edge_type** (str) - Specifies the type of the edges. When the Graph is initialized without specifying `edge_type`, the default value is '0'. See `Loading a graph dataset <https://www.mindspore.cn/tutorials/zh-CN/master/advanced/dataset/augment_graph_data.html>`_ for details.

Returns:
numpy.ndarray, array of edges.

Raises:
- **TypeError** - `edge_type` is not of type string.
.. py:method:: get_all_neighbors(node_list, neighbor_type, output_format=OutputFormat.NORMAL)
Get all neighbor nodes of the nodes in `node_list`, returned with type `neighbor_type`. The format definitions are illustrated by the following example: 1 means two nodes are connected, 0 means they are not.
.. list-table:: Adjacency matrix
:widths: 20 20 20 20 20
:header-rows: 1
* -
- 0
- 1
- 2
- 3
* - 0
- 0
- 1
- 0
- 0
* - 1
- 0
- 0
- 1
- 0
* - 2
- 1
- 0
- 0
- 1
* - 3
- 0
- 1
- 0
- 0
.. list-table:: Normal format
:widths: 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 3
* - dst_0
- 1
- 2
- 0
- 1
* - dst_1
- -1
- -1
- 3
- -1
.. list-table:: COO format
:widths: 20 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 2
- 3
* - dst
- 1
- 2
- 0
- 3
- 1
.. list-table:: CSR format
:widths: 40 20 20 20 20 20
:header-rows: 1
* - offsetTable
- 0
- 1
- 2
- 4
-
* - dstTable
- 1
- 2
- 0
- 3
- 1
Parameters:
- **node_list** (Union[list, numpy.ndarray]) - The given list of nodes.
- **neighbor_type** (str) - Specifies the type of the neighbor nodes.
- **output_format** (OutputFormat, optional) - Output storage format. Default: mindspore.dataset.OutputFormat.NORMAL. Possible values: [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR].

Returns:
For the normal or COO format, a numpy.ndarray representing the neighbor nodes is returned. If the CSR format is specified, two numpy.ndarray arrays are returned: the first is the offset table, the second the neighbor nodes.

Raises:
- **TypeError** - `node_list` is not a list or numpy.ndarray.
- **TypeError** - `neighbor_type` is not of type string.
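A query sketch using the graph `g` built in the constructor example above (return shapes follow the format descriptions; exact contents depend on the stored graph):

.. code-block:: python

    from mindspore.dataset import OutputFormat

    # Normal format: one padded row per input node, -1 marks a missing neighbor.
    normal = g.get_all_neighbors([0, 1, 2, 3], neighbor_type='0',
                                 output_format=OutputFormat.NORMAL)
    # CSR format: two arrays, the offset table and the destination table.
    offsets, dst = g.get_all_neighbors([0, 1, 2, 3], neighbor_type='0',
                                       output_format=OutputFormat.CSR)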
.. py:method:: get_all_nodes(node_type)
Get all nodes of the graph.

Parameters:
- **node_type** (str) - Specifies the type of the nodes. When the Graph is initialized without specifying `node_type`, the default value is '0'. See `Loading a graph dataset <https://www.mindspore.cn/tutorials/zh-CN/master/advanced/dataset/augment_graph_data.html>`_ for details.

Returns:
numpy.ndarray, array of nodes.

Raises:
- **TypeError** - `node_type` is not of type string.
.. py:method:: get_edge_feature(edge_list, feature_types)
Get the features of the edges in `edge_list`, for the given `feature_types`.

Parameters:
- **edge_list** (Union[list, numpy.ndarray]) - List of edges.
- **feature_types** (Union[list, numpy.ndarray]) - List of the given feature types; each element of the list is of type string.

Returns:
numpy.ndarray, array of features.

Raises:
- **TypeError** - `edge_list` is not a list or numpy.ndarray.
- **TypeError** - `feature_types` is not a list or numpy.ndarray.
.. py:method:: get_edges_from_nodes(node_list)
Get edges from nodes.

Parameters:
- **node_list** (Union[list[tuple], numpy.ndarray]) - List containing one or more pairs of graph node IDs.

Returns:
numpy.ndarray, array containing one or more edge IDs.

Raises:
- **TypeError** - `node_list` is not a list or numpy.ndarray.
.. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
Get negatively sampled neighbor nodes of all nodes in `node_list`, returned with type `neg_neighbor_type`.

Parameters:
- **node_list** (Union[list, numpy.ndarray]) - List of nodes.
- **neg_neighbor_num** (int) - Number of neighbor nodes to sample.
- **neg_neighbor_type** (str) - Specifies the type of the negatively sampled neighbor nodes.

Returns:
numpy.ndarray, array of neighbor nodes.

Raises:
- **TypeError** - `node_list` is not a list or numpy.ndarray.
- **TypeError** - `neg_neighbor_num` is not an integer.
- **TypeError** - `neg_neighbor_type` is not of type string.
.. py:method:: get_node_feature(node_list, feature_types)
Get the features of the nodes in `node_list`, for the given `feature_types`.

Parameters:
- **node_list** (Union[list, numpy.ndarray]) - List of nodes.
- **feature_types** (Union[list, numpy.ndarray]) - List specifying the feature types; each element of the list should be of type string.

Returns:
numpy.ndarray, array of features.

Raises:
- **TypeError** - `node_list` is not a list or numpy.ndarray.
- **TypeError** - `feature_types` is not a list or numpy.ndarray.
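For example, with the graph `g` from the constructor sketch above, whose only registered feature type is 'feature' (an assumption of that sketch):

.. code-block:: python

    # One array per requested feature type, aligned with node_list.
    node_features = g.get_node_feature(node_list=[0, 1], feature_types=["feature"])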
.. py:method:: get_nodes_from_edges(edge_list)
Get nodes from the edges of the graph.

Parameters:
- **edge_list** (Union[list, numpy.ndarray]) - List of edges.

Returns:
numpy.ndarray, array of nodes.

Raises:
- **TypeError** - `edge_list` is not a list or ndarray.
.. py:method:: get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM)
Get sampled neighbor information. This API supports multi-hop neighbor sampling, i.e. the result of the previous sampling is used as the input of the next hop's sampling, with at most 6 hops allowed. The sampling results are flattened into a list with the format [input node, 1-hop sampling result, 2-hop sampling result ...].

Parameters:
- **node_list** (Union[list, numpy.ndarray]) - List of nodes.
- **neighbor_nums** (Union[list, numpy.ndarray]) - Number of neighbor nodes sampled per hop.
- **neighbor_types** (Union[list, numpy.ndarray]) - Types of the neighbor nodes sampled per hop.
- **strategy** (SamplingStrategy, optional) - Sampling strategy. Default: mindspore.dataset.SamplingStrategy.RANDOM. Possible values: [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT].

  - **SamplingStrategy.RANDOM**: random sampling, with replacement.
  - **SamplingStrategy.EDGE_WEIGHT**: sampling with the edge weights as probabilities.

Returns:
numpy.ndarray, array of neighbor nodes.

Raises:
- **TypeError** - `node_list` is not a list or numpy.ndarray.
- **TypeError** - `neighbor_nums` is not a list or numpy.ndarray.
- **TypeError** - `neighbor_types` is not a list or numpy.ndarray.
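A two-hop sampling sketch on the example graph `g` (results are random; each output row follows the flattened layout described above):

.. code-block:: python

    from mindspore.dataset import SamplingStrategy

    # Sample 2 one-hop neighbors per input node, then 1 two-hop neighbor each.
    sampled = g.get_sampled_neighbors(node_list=[0, 1],
                                      neighbor_nums=[2, 1],
                                      neighbor_types=['0', '0'],
                                      strategy=SamplingStrategy.RANDOM)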
.. py:method:: graph_info()
Get the meta information of the graph, including the number of nodes, node types, node feature information, the number of edges, edge types, and edge feature information.

Returns:
dict, meta information of the graph. The keys are `node_num`, `node_type`, `node_feature_type`, `edge_num`, `edge_type`, `edge_feature_type` and `graph_feature_type`.
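For example, for the graph `g` built above:

.. code-block:: python

    info = g.graph_info()
    print(info["node_num"], info["edge_num"])  # 4 and 5 for the example graph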
.. py:method:: random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)
Random walk over the nodes.

Parameters:
- **target_nodes** (list[int]) - List of start nodes of the random walk.
- **meta_path** (list[int]) - Node type for each walk step.
- **step_home_param** (float, optional) - The return hyperparameter in the `node2vec algorithm <https://www.kdd.org/kdd2016/papers/files/rfp0218-groverA.pdf>`_. Default: 1.0.
- **step_away_param** (float, optional) - The in-out hyperparameter in the `node2vec algorithm <https://www.kdd.org/kdd2016/papers/files/rfp0218-groverA.pdf>`_. Default: 1.0.
- **default_node** (int, optional) - The default node if no further neighbor node can be found. Default: -1, which means no node is given.

Returns:
numpy.ndarray, array of nodes.

Raises:
- **TypeError** - `target_nodes` is not a list or numpy.ndarray.
- **TypeError** - `meta_path` is not a list or numpy.ndarray.
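A sketch of homogeneous walks on the example graph `g` (using the integer form 0 of the default node type '0' in `meta_path` is an assumption; results are stochastic):

.. code-block:: python

    # Three-step walks starting from nodes 0 and 1; default_node=-1 fills a
    # step when no neighbor can be found.
    walks = g.random_walk(target_nodes=[0, 1],
                          meta_path=[0, 0, 0],
                          step_home_param=1.0,
                          step_away_param=1.0,
                          default_node=-1)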


@@ -0,0 +1,46 @@
mindspore.dataset.InMemoryGraphDataset
======================================
.. py:class:: mindspore.dataset.InMemoryGraphDataset(data_dir, save_dir="./processed", column_names="graph", num_samples=None, num_parallel_workers=1, shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6)
Base Dataset class for loading graph data into memory.

It is recommended to implement your own Dataset by inheriting this class and implementing the corresponding methods, such as 'process', 'save' and 'load'.

Parameters:
- **data_dir** (str) - Directory for loading the dataset; it contains data in the original format, which will be loaded in the `process` method.
- **save_dir** (str) - Relative directory for saving the processed dataset; this directory is under `data_dir`.
- **column_names** (Union[str, list[str]], optional) - Single column name of the dataset, or a list of multiple column names. Default: 'graph'. When implementing a method such as `__getitem__`, the number of column names should equal the number of items returned by that method.
- **num_samples** (int, optional) - Number of samples to read from the dataset. Default: None, which reads all samples.
- **num_parallel_workers** (int, optional) - Number of worker processes/threads used to read the data (whether multiprocess or multithread mode is used is determined by `python_multiprocessing`). Default: 1.
- **shuffle** (bool, optional) - Whether to shuffle the dataset. This parameter can only be specified when the implemented Dataset is randomly accessible via `__getitem__`. Default: None.
- **num_shards** (int, optional) - Number of shards that the dataset will be divided into for distributed training. Default: None. When this parameter is specified, `num_samples` reflects the maximum number of samples per shard.
- **shard_id** (int, optional) - The shard ID used for distributed training. Default: None. This parameter can only be specified when `num_shards` is also specified.
- **python_multiprocessing** (bool, optional) - Enable Python multiprocessing to accelerate computation. Default: True. This option can be beneficial when the computation on the Python object passed as `source` is heavy.
- **max_rowsize** (int, optional) - Maximum space, in MB, allocated in shared memory for copying data between processes. Default: 6. This parameter only takes effect when `python_multiprocessing` is set to True.
.. py:method:: process()

Processing method for the original dataset. It is recommended to override this method in your custom Dataset.

.. py:method:: save()

Save the data processed by the `process` method to disk in numpy.npz format. You can also implement this method yourself in your own Dataset class.

.. py:method:: load()

Load data from the given (processed) path. You can also implement this method yourself in your own Dataset class.
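A minimal subclassing sketch follows (the class name and the parsing logic in `process` are hypothetical; `process` is expected to fill `self.graphs` from the raw files under `data_dir`):

.. code-block:: python

    import numpy as np
    from mindspore.dataset import Graph, InMemoryGraphDataset

    class MyGraphDataset(InMemoryGraphDataset):
        def __init__(self, data_dir):
            super().__init__(data_dir, column_names="graph")

        def process(self):
            # Hypothetical parsing step: build one fixed 3-node graph with
            # edges 0->1 and 1->2 instead of reading the raw files.
            edges = np.array([[0, 1], [1, 2]], dtype=np.int32)
            self.graphs = [Graph(edges)]

        def __getitem__(self, index):
            return self.graphs[index].get_all_edges('0')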
.. include:: mindspore.dataset.Dataset.add_sampler.rst
.. include:: mindspore.dataset.Dataset.rst
.. include:: mindspore.dataset.Dataset.b.rst
.. include:: mindspore.dataset.Dataset.c.rst
.. include:: mindspore.dataset.Dataset.d.rst
.. include:: mindspore.dataset.Dataset.use_sampler.rst
.. include:: mindspore.dataset.Dataset.zip.rst


@@ -131,8 +131,10 @@ mindspore.dataset
.. mscnautosummary::
:toctree: dataset
+    mindspore.dataset.ArgoverseDataset
+    mindspore.dataset.Graph
     mindspore.dataset.GraphData
+    mindspore.dataset.InMemoryGraphDataset
Sampler
-------


@@ -113,7 +113,10 @@ Graph
:nosignatures:
:template: classtemplate_inherited.rst
+    mindspore.dataset.ArgoverseDataset
+    mindspore.dataset.Graph
     mindspore.dataset.GraphData
+    mindspore.dataset.InMemoryGraphDataset
Sampler
--------


@@ -22,7 +22,6 @@ import random
import time
from enum import IntEnum
import numpy as np
-import pandas as pd
from mindspore._c_dataengine import GraphDataClient
from mindspore._c_dataengine import GraphDataServer
from mindspore._c_dataengine import Tensor
@@ -138,22 +137,18 @@ class GraphData:
        if working_mode in ['local', 'client']:
            self._graph_data = GraphDataClient(self.data_format, dataset_file, num_parallel_workers, working_mode,
                                               hostname, port)
-            atexit.register(self.stop)
+            atexit.register(self._stop)

        if working_mode == 'server':
            self._graph_data = GraphDataServer(
                self.data_format, dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
-            atexit.register(self.stop)
+            atexit.register(self._stop)
            try:
                while self._graph_data.is_stopped() is not True:
                    time.sleep(1)
            except KeyboardInterrupt:
                raise Exception("Graph data server receives KeyboardInterrupt.")

-    def stop(self):
-        """Stop GraphDataClient or GraphDataServer."""
-        self._graph_data.stop()
-
    @check_gnn_get_all_nodes
    def get_all_nodes(self, node_type):
        """
@check_gnn_get_all_nodes
def get_all_nodes(self, node_type):
"""
@@ -530,23 +525,30 @@ class GraphData:
        return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
                                            default_node).as_array()

+    def _stop(self):
+        """Stop GraphDataClient or GraphDataServer."""
+        self._graph_data.stop()
+

class Graph(GraphData):
    """
-    A graph object for storing Graph structure and feature data.
+    A graph object for storing Graph structure and feature data, and providing capabilities such as graph sampling.

-    This class supports init graph With input numpy array data, which represent edge, node and its features.
+    This class supports initializing a graph with input numpy array data, which represents nodes, edges and their
+    features.
+    If working mode is `local`, there is no need to specify input arguments like `working_mode`, `hostname`, `port`,
+    `num_client`, `auto_shutdown`.

    Args:
        edges(Union[list, numpy.ndarray]): edges of graph in COO format with shape [2, num_edges].
-        node_feat(dict, optional): feature of nodes, key is feature type, value should be numpy.array with shape
-            [num_nodes, num_node_features], feature type should be string, like 'weight' etc.
-        edge_feat(dict, optional): feature of edges, key is feature type, value should be numpy.array with shape
-            [num_edges, num_edge_features], feature type should be string, like 'weight' etc.
-        graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, key is
-            feature type, value should be numpy.array.
+        node_feat(dict, optional): feature of nodes, input data format should be dict, key is feature type, which is
+            represented with a string like 'weight' etc., value should be numpy.array with shape
+            [num_nodes, num_node_features].
+        edge_feat(dict, optional): feature of edges, input data format should be dict, key is feature type, which is
+            represented with a string like 'weight' etc., value should be numpy.array with shape
+            [num_edges, num_edge_features].
+        graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, input
+            data format should be dict, key is feature type, which is represented with a string, value should be
+            numpy.array whose shape is not restricted.
        node_type(Union[list, numpy.ndarray], optional): type of nodes, each element should be string which represent
            type of corresponding node. If not provided, default type for each node is '0'.
        edge_type(Union[list, numpy.ndarray], optional): type of edges, each element should be string which represent
@@ -630,23 +632,19 @@ class Graph(GraphData):
            self._graph_data = GraphDataClient(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat,
                                               node_type, edge_type, num_parallel_workers, working_mode, hostname,
                                               port)
-            atexit.register(self.stop)
+            atexit.register(self._stop)

        if working_mode == 'server':
            self._graph_data = GraphDataServer(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat,
                                               node_type, edge_type, num_parallel_workers, hostname, port, num_client,
                                               auto_shutdown)
-            atexit.register(self.stop)
+            atexit.register(self._stop)
            try:
                while self._graph_data.is_stopped() is not True:
                    time.sleep(1)
            except KeyboardInterrupt:
                raise Exception("Graph data server receives KeyboardInterrupt.")

-    def stop(self):
-        """Stop GraphDataClient or GraphDataServer."""
-        self._graph_data.stop()
-
    @check_gnn_get_all_nodes
    def get_all_nodes(self, node_type):
        """
@@ -679,7 +677,8 @@ class Graph(GraphData):
        Get all edges in the graph.

        Args:
-            edge_type (int): Specify the type of edge.
+            edge_type (str): Specify the type of edge; default edge_type is '0' when the graph is initialized
+                without specifying edge_type.

        Returns:
            numpy.ndarray, array of edges.
@@ -1037,7 +1036,7 @@ class Graph(GraphData):
        Returns:
            dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
-                node_feature_type and edge_feature_type.
+                node_feature_type, edge_feature_type and graph_feature_type.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")
@@ -1223,11 +1222,33 @@ class _UsersDatasetTemplate:
class InMemoryGraphDataset(GeneratorDataset):
    """
-    The basic Dataset for loading graph into memory.
-    Recommended to inherit this class, and implement your own method like 'process', 'save' and 'load'.
+    Basic Dataset for loading graph into memory.
+    Recommended to implement your own dataset by inheriting this class, and to implement your own methods like
+    'process', 'save' and 'load'.
+
+    Args:
+        data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in
+            `process` method.
+        save_dir (str): relative directory for saving processed dataset, this directory is under `data_dir`.
+        column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset;
+            the number of column names should be equal to the number of items returned when implementing a method
+            like `__getitem__`.
+        num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples).
+        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel
+            (default=1).
+        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
+            required (default=None, expected order behavior shown in the table).
+        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
+            Random accessible input is required. When this argument is specified, `num_samples` reflects the max
+            sample number per shard.
+        shard_id (int, optional): The shard ID within `num_shards` (default=None). This argument can only be
+            specified when `num_shards` is also specified. Random accessible input is required.
+        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
+            option could be beneficial if the Python operation is computationally heavy (default=True).
+        max_rowsize (int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
+            data between processes. This is only used if python_multiprocessing is set to True (default 6 MB).
    """

-    def __init__(self, data_dir, column_names="graph", save_dir="./processed", num_parallel_workers=1,
+    def __init__(self, data_dir, save_dir="./processed", column_names="graph", num_samples=None, num_parallel_workers=1,
                 shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6):
        self.graphs = []
        self.data_dir = data_dir
@@ -1244,25 +1265,25 @@ class InMemoryGraphDataset(GeneratorDataset):
            setattr(source, k, v)
        for k, v in self.__class__.__dict__.items():
            setattr(source.__class__, k, getattr(self.__class__, k))
-        super().__init__(source, column_names=column_names, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
-                         num_shards=num_shards, shard_id=shard_id, python_multiprocessing=python_multiprocessing,
-                         max_rowsize=max_rowsize)
+        super().__init__(source, column_names=column_names, num_samples=num_samples,
+                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, num_shards=num_shards,
+                         shard_id=shard_id, python_multiprocessing=python_multiprocessing, max_rowsize=max_rowsize)

    def process(self):
        """
-        Override this method in your our dataset class.
+        Process method based on the origin dataset; override this method in your own dataset class.
        """
        raise NotImplementedError("'process' method should be implemented in your own logic.")

    def save(self):
        """
-        Override this method in your our dataset class.
+        Save processed data to disk in numpy.npz format; you can also override this method in your own dataset
+        class.
        """
        save_graphs(self.processed_path, self.graphs)

    def load(self):
        """
-        Override this method in your our dataset class.
+        Load data from the given (processed) path; you can also override this method in your own dataset class.
        """
        self.graphs = load_graphs(self.processed_path, num_parallel_workers=1)
@@ -1280,14 +1301,32 @@ class InMemoryGraphDataset(GeneratorDataset):
class ArgoverseDataset(InMemoryGraphDataset):
    """
    Load argoverse dataset and create graph.
+
+    The Argoverse dataset is a public dataset for autonomous driving. The current `ArgoverseDataset` implementation
+    is mainly for loading the Motion Forecasting Dataset within Argoverse; it is recommended to visit the official
+    website for more detail: https://www.argoverse.org/av1.html#download-link.
+
+    Args:
+        data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in
+            `process` method.
+        column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset;
+            the number of column names should be equal to the number of items returned when implementing a method
+            like `__getitem__`.
+        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel
+            (default=1).
+        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
+            required (default=None, expected order behavior shown in the table).
+        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
+            option could be beneficial if the Python operation is computationally heavy (default=True).
+        perf_mode (bool, optional): mode for obtaining higher performance when iterating over the created dataset
+            (the `__getitem__` method is called in this process). Default: True, which saves all the data in the
+            graph (like edge index, node feature and graph feature) into graph feature.
    """

-    def __init__(self, data_dir, column_names="graph", shuffle=None, num_parallel_workers=1,
+    def __init__(self, data_dir, column_names="graph", num_parallel_workers=1, shuffle=None,
                 python_multiprocessing=True, perf_mode=True):
        # For high performance, here we store edge_index into graph_feature directly
        self.perf_mode = perf_mode
-        super().__init__(data_dir, column_names, shuffle=shuffle, num_parallel_workers=num_parallel_workers,
-                         python_multiprocessing=python_multiprocessing)
+        super().__init__(data_dir=data_dir, column_names=column_names, shuffle=shuffle,
+                         num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing)

    def __getitem__(self, index):
        graph = self.graphs[index]
@@ -1312,6 +1351,10 @@ class ArgoverseDataset(InMemoryGraphDataset):
        """
        process method mainly refers to: https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py
        """
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ImportError("Import pandas failed, recommend to install pandas with pip.")

        def get_edge_full_connection(node_num, start_index=0):
            """