!38227 add chinese api for Graph API
Merge pull request !38227 from ms_yan/CN_api_graph
Commit eb4030b089
@@ -0,0 +1,31 @@
mindspore.dataset.ArgoverseDataset
==================================

.. py:class:: mindspore.dataset.ArgoverseDataset(data_dir, column_names="graph", shuffle=None, num_parallel_workers=1, python_multiprocessing=True, perf_mode=True)

    Load the Argoverse dataset and perform graph (Graph) initialization.

    The Argoverse dataset is a public dataset for autonomous driving. The current `ArgoverseDataset` implementation mainly loads the Motion Forecasting portion of the Argoverse dataset; see the official website for details:
    https://www.argoverse.org/av1.html#download-link

    Parameters:
        - **data_dir** (str) - Directory to load the dataset from; it contains the data in its original format, which is loaded in the `process` method.
        - **column_names** (Union[str, list[str]], optional) - Single column name, or a list of column names, of the dataset. Default: 'graph'. When a method such as `__getitem__` is implemented, the number of column names should equal the number of items returned by that method.
        - **num_parallel_workers** (int, optional) - Number of worker processes/threads used to read the data (whether multiprocessing or multithreading is used is determined by `python_multiprocessing`). Default: 1.
        - **shuffle** (bool, optional) - Whether to shuffle the dataset. This parameter can only be specified when the implemented Dataset is randomly accessible (implements `__getitem__`). Default: None.
        - **python_multiprocessing** (bool, optional) - Use Python multiprocessing to speed up computation. Default: True. Enabling this option may help when the Python logic applied to `source` is computationally heavy.
        - **perf_mode** (bool, optional) - Mode for higher performance when iterating over the created dataset (the `__getitem__` method is called during iteration). Default: True, which stores all data of the Graph (such as edge indices, node features, and graph features) as graph features.
    .. include:: mindspore.dataset.Dataset.add_sampler.rst

    .. include:: mindspore.dataset.Dataset.rst

    .. include:: mindspore.dataset.Dataset.b.rst

    .. include:: mindspore.dataset.Dataset.c.rst

    .. include:: mindspore.dataset.Dataset.d.rst

    .. include:: mindspore.dataset.Dataset.use_sampler.rst

    .. include:: mindspore.dataset.Dataset.zip.rst
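For orientation, a minimal usage sketch (not part of this patch). The directory path is hypothetical and assumes the Motion Forecasting CSVs have already been downloaded and extracted locally; the column names follow the multi-column return of the dataset's `__getitem__`:

import mindspore.dataset as ds

# Hypothetical local path to the extracted Motion Forecasting data.
data_dir = "/path/to/argoverse_forecasting/train/data"

# On first use the raw CSVs are parsed by `process` and cached; later runs
# restore the processed graphs from disk via `load`.
graph_dataset = ds.ArgoverseDataset(
    data_dir,
    column_names=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"])

for item in graph_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
    pass  # each `item` holds one preprocessed sequence graph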
@@ -0,0 +1,280 @@
mindspore.dataset.Graph
=======================

.. py:class:: mindspore.dataset.Graph(edges, node_feat=None, edge_feat=None, graph_feat=None, node_type=None, edge_type=None, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)

    Mainly used to store the structural information and feature attributes of a graph, and to provide capabilities such as graph sampling.

    This interface supports initializing a graph from NumPy arrays that represent the nodes, edges, and their features. If `working_mode` is the default 'local', there is no need to specify the `working_mode`, `hostname`, `port`, `num_client`, and `auto_shutdown` arguments.

    Parameters:
        - **edges** (Union[list, numpy.ndarray]) - Edges in COO format, with shape [2, num_edges].
        - **node_feat** (dict, optional) - Features of the nodes. The input should be a dict whose keys are the feature types, given as strings such as 'weight'; each value should be a NumPy array with shape [num_nodes, num_node_features].
        - **edge_feat** (dict, optional) - Features of the edges. The input should be a dict whose keys are the feature types, given as strings such as 'weight'; each value should be a NumPy array with shape [num_edges, num_edge_features].
        - **graph_feat** (dict, optional) - Additional features that cannot be assigned to `node_feat` or `edge_feat`. The input should be a dict whose keys are the feature types, given as strings; each value should be a NumPy array whose shape is not restricted.
        - **node_type** (Union[list, numpy.ndarray], optional) - Types of the nodes; each element is a string representing the type of the corresponding node. If not provided, the default type of each node is '0'.
        - **edge_type** (Union[list, numpy.ndarray], optional) - Types of the edges; each element is a string representing the type of the corresponding edge. If not provided, the default type of each edge is '0'.
        - **num_parallel_workers** (int, optional) - Number of worker threads used to read the data. Default: None, which uses the number of threads configured in mindspore.dataset.config.
        - **working_mode** (str, optional) - Working mode; currently 'local', 'client', and 'server' are supported. Default: 'local'.

          - **local**: used for non-distributed training scenarios.
          - **client**: used for distributed training scenarios. The client does not load data; it fetches data from the server.
          - **server**: used for distributed training scenarios. The server loads the data and serves it to clients.

        - **hostname** (str, optional) - Hostname of the graph data server. This parameter is only valid when the working mode is 'client' or 'server'. Default: '127.0.0.1'.
        - **port** (int, optional) - Port of the graph data server, in the range 1024-65535. This parameter is only valid when the working mode is 'client' or 'server'. Default: 50051.
        - **num_client** (int, optional) - Maximum number of clients expected to connect to the server. The server allocates resources according to this parameter. It is only valid when the working mode is 'server'. Default: 1.
        - **auto_shutdown** (bool, optional) - Valid when the working mode is 'server'. The server exits automatically once the number of connected clients has reached `num_client` and no client is still connected. Default: True.

    Raises:
        - **ValueError** - If `num_parallel_workers` exceeds the maximum number of threads of the system.
        - **ValueError** - If `working_mode` is not 'local', 'client', or 'server'.
        - **TypeError** - If `hostname` has an invalid type.
        - **ValueError** - If `port` is not in the range [1024, 65535].
        - **ValueError** - If `num_client` is not in the range [1, 255].
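A minimal construction sketch in local mode (illustration only, not part of this patch); the edge list and feature values are made-up toy data:

import numpy as np
import mindspore.dataset as ds

# Toy graph: 4 nodes, 5 directed edges in COO format, shape [2, num_edges].
edges = np.array([[0, 1, 2, 2, 3],
                  [1, 2, 0, 3, 1]], dtype=np.int32)
# One feature type per dict key; values are [num_nodes, num_node_features].
node_feat = {"feature": np.random.rand(4, 8).astype(np.float32)}
edge_feat = {"weight": np.random.rand(5, 1).astype(np.float32)}

g = ds.Graph(edges, node_feat=node_feat, edge_feat=edge_feat)

nodes = g.get_all_nodes(node_type="0")      # default node type is '0'
all_edges = g.get_all_edges(edge_type="0")  # default edge type is '0'

The later method sketches below reuse this `g`.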
    .. py:method:: get_all_edges(edge_type)

        Get all edges of the graph.

        Parameters:
            - **edge_type** (str) - Type of the edges to fetch. When `edge_type` was not specified at Graph initialization, the default value is '0'. See `Loading a Graph Dataset <https://www.mindspore.cn/tutorials/zh-CN/master/advanced/dataset/augment_graph_data.html>`_ for details.

        Returns:
            numpy.ndarray, an array containing the edges.

        Raises:
            - **TypeError** - If `edge_type` is not of type string.
    .. py:method:: get_all_neighbors(node_list, neighbor_type, output_format=OutputFormat.NORMAL)

        Get all neighbors of the nodes in `node_list`, filtered by `neighbor_type`. The output formats are defined by the following example: 1 means two nodes are connected, 0 means they are not.

        .. list-table:: Adjacency matrix
           :widths: 20 20 20 20 20
           :header-rows: 1

           * -
             - 0
             - 1
             - 2
             - 3
           * - 0
             - 0
             - 1
             - 0
             - 0
           * - 1
             - 0
             - 0
             - 1
             - 0
           * - 2
             - 1
             - 0
             - 0
             - 1
           * - 3
             - 1
             - 0
             - 0
             - 0

        .. list-table:: Normal format
           :widths: 20 20 20 20 20
           :header-rows: 1

           * - src
             - 0
             - 1
             - 2
             - 3
           * - dst_0
             - 1
             - 2
             - 0
             - 1
           * - dst_1
             - -1
             - -1
             - 3
             - -1

        .. list-table:: COO format
           :widths: 20 20 20 20 20 20
           :header-rows: 1

           * - src
             - 0
             - 1
             - 2
             - 2
             - 3
           * - dst
             - 1
             - 2
             - 0
             - 3
             - 1

        .. list-table:: CSR format
           :widths: 40 20 20 20 20 20
           :header-rows: 1

           * - offsetTable
             - 0
             - 1
             - 2
             - 4
             -
           * - dstTable
             - 1
             - 2
             - 0
             - 3
             - 1

        Parameters:
            - **node_list** (Union[list, numpy.ndarray]) - The given list of nodes.
            - **neighbor_type** (str) - Type of the neighbor nodes to fetch.
            - **output_format** (OutputFormat, optional) - Output storage format. Default: mindspore.dataset.OutputFormat.NORMAL. Valid values: [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR].

        Returns:
            For the normal and COO formats, a numpy.ndarray representing the neighbor nodes is returned. If the CSR format is specified, two numpy.ndarray arrays are returned: the first is the offset table and the second holds the neighbor nodes.

        Raises:
            - **TypeError** - If `node_list` is not a list or numpy.ndarray.
            - **TypeError** - If `neighbor_type` is not of type string.
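A short sketch of how `output_format` changes the return value, continuing from the toy graph `g` built above (illustrative only):

from mindspore.dataset import OutputFormat

# Normal format: one padded row per input node, -1 marks "no more neighbors".
neighbors = g.get_all_neighbors([0, 1, 2, 3], neighbor_type="0")

# CSR format: two arrays, an offset table plus a flat destination table.
offsets, dst = g.get_all_neighbors([0, 1, 2, 3], neighbor_type="0",
                                   output_format=OutputFormat.CSR)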
    .. py:method:: get_all_nodes(node_type)

        Get all nodes of the graph.

        Parameters:
            - **node_type** (str) - Type of the nodes to fetch. When `node_type` was not specified at Graph initialization, the default value is '0'. See `Loading a Graph Dataset <https://www.mindspore.cn/tutorials/zh-CN/master/advanced/dataset/augment_graph_data.html>`_ for details.

        Returns:
            numpy.ndarray, an array containing the nodes.

        Raises:
            - **TypeError** - If `node_type` is not of type string.
    .. py:method:: get_edge_feature(edge_list, feature_types)

        Get the features of the edges in `edge_list`, for the feature types in `feature_types`.

        Parameters:
            - **edge_list** (Union[list, numpy.ndarray]) - List of edges.
            - **feature_types** (Union[list, numpy.ndarray]) - List of the requested feature types; each element in the list should be a string.

        Returns:
            numpy.ndarray, an array containing the features.

        Raises:
            - **TypeError** - If `edge_list` is not a list or numpy.ndarray.
            - **TypeError** - If `feature_types` is not a list or numpy.ndarray.
    .. py:method:: get_edges_from_nodes(node_list)

        Get edges from pairs of nodes.

        Parameters:
            - **node_list** (Union[list[tuple], numpy.ndarray]) - List containing one or more pairs of graph node IDs.

        Returns:
            numpy.ndarray, an array containing one or more edge IDs.

        Raises:
            - **TypeError** - If `node_list` is not a list or numpy.ndarray.
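A lookup sketch mapping node pairs to edge IDs, continuing from the toy graph `g` above (both pairs exist as edges there; illustrative only):

# Each tuple is a (src, dst) pair; the result holds the matching edge IDs.
edge_ids = g.get_edges_from_nodes([(0, 1), (2, 3)])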
    .. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)

        Get negatively sampled neighbors of all nodes in `node_list`, filtered by `neg_neighbor_type`.

        Parameters:
            - **node_list** (Union[list, numpy.ndarray]) - List of nodes.
            - **neg_neighbor_num** (int) - Number of neighbor nodes to sample.
            - **neg_neighbor_type** (str) - Type of the negatively sampled neighbor nodes.

        Returns:
            numpy.ndarray, an array containing the neighbor nodes.

        Raises:
            - **TypeError** - If `node_list` is not a list or numpy.ndarray.
            - **TypeError** - If `neg_neighbor_num` is not an integer.
            - **TypeError** - If `neg_neighbor_type` is not of type string.
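Negative sampling draws nodes that are not neighbors of the input nodes, which is typically used to build contrastive training pairs. A small sketch continuing from the toy graph `g` above (parameter values are arbitrary):

# For each input node, draw 2 non-neighbor nodes of type '0'.
neg_samples = g.get_neg_sampled_neighbors([0, 1], neg_neighbor_num=2,
                                          neg_neighbor_type="0")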
    .. py:method:: get_node_feature(node_list, feature_types)

        Get the features of the nodes in `node_list`, for the feature types in `feature_types`.

        Parameters:
            - **node_list** (Union[list, numpy.ndarray]) - List of nodes.
            - **feature_types** (Union[list, numpy.ndarray]) - Requested feature types; each element in the list should be a string.

        Returns:
            numpy.ndarray, an array containing the features.

        Raises:
            - **TypeError** - If `node_list` is not a list or numpy.ndarray.
            - **TypeError** - If `feature_types` is not a list or numpy.ndarray.
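A feature-lookup sketch, reusing the 'feature' and 'weight' keys from the construction example above (illustrative only):

# Features come back in the order of the requested feature types.
node_x = g.get_node_feature([0, 1, 2], feature_types=["feature"])
edge_w = g.get_edge_feature([0, 1], feature_types=["weight"])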
    .. py:method:: get_nodes_from_edges(edge_list)

        Get the nodes belonging to the given edges of the graph.

        Parameters:
            - **edge_list** (Union[list, numpy.ndarray]) - List of edges.

        Returns:
            numpy.ndarray, an array containing the nodes.

        Raises:
            - **TypeError** - If `edge_list` is not a list or ndarray.
    .. py:method:: get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM)

        Get sampled neighbor information. This API supports multi-hop neighbor sampling: the result of one sampling step is used as the input of the next hop, for at most 6 hops. The sampling result is flattened into a list in the format [input node, 1-hop sampling result, 2-hop sampling result ...].

        Parameters:
            - **node_list** (Union[list, numpy.ndarray]) - List of nodes.
            - **neighbor_nums** (Union[list, numpy.ndarray]) - Number of neighbors sampled per hop.
            - **neighbor_types** (Union[list, numpy.ndarray]) - Type of the neighbors sampled per hop.
            - **strategy** (SamplingStrategy, optional) - Sampling strategy. Default: mindspore.dataset.SamplingStrategy.RANDOM. Valid values: [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT].

              - **SamplingStrategy.RANDOM**: random sampling, with replacement.
              - **SamplingStrategy.EDGE_WEIGHT**: sampling with the edge weights as probabilities.

        Returns:
            numpy.ndarray, an array containing the neighbor nodes.

        Raises:
            - **TypeError** - If `node_list` is not a list or numpy.ndarray.
            - **TypeError** - If `neighbor_nums` is not a list or numpy.ndarray.
            - **TypeError** - If `neighbor_types` is not a list or numpy.ndarray.
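A 2-hop sampling sketch, continuing from the toy graph `g` above: each result row flattens the input node, its sampled 1-hop neighbors, and the 2-hop neighbors sampled from each of those (parameter values are arbitrary):

from mindspore.dataset import SamplingStrategy

neighbors = g.get_sampled_neighbors(
    node_list=[0, 1],
    neighbor_nums=[2, 3],          # 2 neighbors at hop 1, 3 at hop 2
    neighbor_types=["0", "0"],     # neighbor type per hop
    strategy=SamplingStrategy.RANDOM)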
    .. py:method:: graph_info()

        Get the meta information of the graph, including the number of nodes, node types, node feature information, number of edges, edge types, and edge feature information.

        Returns:
            dict, the meta information of the graph. The keys are `node_num`, `node_type`, `node_feature_type`, `edge_num`, `edge_type`, `edge_feature_type`, and `graph_feature_type`.
    .. py:method:: random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)

        Random walk over the nodes.

        Parameters:
            - **target_nodes** (list[int]) - List of start nodes for the random walk.
            - **meta_path** (list[int]) - Node type for each walk step.
            - **step_home_param** (float, optional) - The return hyperparameter in the `node2vec algorithm <https://www.kdd.org/kdd2016/papers/files/rfp0218-groverA.pdf>`_. Default: 1.0.
            - **step_away_param** (float, optional) - The in-out hyperparameter in the `node2vec algorithm <https://www.kdd.org/kdd2016/papers/files/rfp0218-groverA.pdf>`_. Default: 1.0.
            - **default_node** (int, optional) - Default node to fall back to when no further neighbor can be found. Default: -1, meaning no node is given.

        Returns:
            numpy.ndarray, an array containing the nodes.

        Raises:
            - **TypeError** - If `target_nodes` is not a list or numpy.ndarray.
            - **TypeError** - If `meta_path` is not a list or numpy.ndarray.
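A node2vec-style walk sketch, continuing from the toy graph `g` above: `step_home_param` corresponds to the return parameter p and `step_away_param` to the in-out parameter q; the values and the meta-path types are arbitrary assumptions for illustration:

# Walk 3 steps from nodes 0 and 2; every step stays on node type 0.
walks = g.random_walk(target_nodes=[0, 2],
                      meta_path=[0, 0, 0],
                      step_home_param=1.0,   # p: return parameter
                      step_away_param=2.0)   # q: in-out parameter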
@@ -0,0 +1,46 @@
mindspore.dataset.InMemoryGraphDataset
======================================

.. py:class:: mindspore.dataset.InMemoryGraphDataset(data_dir, save_dir="./processed", column_names="graph", num_samples=None, num_parallel_workers=1, shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6)

    Basic Dataset class for loading graphs into memory.

    It is recommended to implement your own Dataset by inheriting this class and implementing the corresponding methods, such as 'process', 'save', and 'load'.

    Parameters:
        - **data_dir** (str) - Directory to load the dataset from; it contains the data in its original format, which is loaded in the `process` method.
        - **save_dir** (str) - Relative directory, under `data_dir`, in which the processed dataset is saved.
        - **column_names** (Union[str, list[str]], optional) - Single column name, or a list of column names, of the dataset. Default: 'graph'. When a method such as `__getitem__` is implemented, the number of column names should equal the number of items returned by that method.
        - **num_samples** (int, optional) - Number of samples to read from the dataset. Default: None, read all samples.
        - **num_parallel_workers** (int, optional) - Number of worker processes/threads used to read the data (whether multiprocessing or multithreading is used is determined by `python_multiprocessing`). Default: 1.
        - **shuffle** (bool, optional) - Whether to shuffle the dataset. This parameter can only be specified when the implemented Dataset is randomly accessible (implements `__getitem__`). Default: None.
        - **num_shards** (int, optional) - Number of shards the dataset is divided into for distributed training. Default: None. When this parameter is specified, `num_samples` reflects the maximum number of samples per shard.
        - **shard_id** (int, optional) - Shard ID used for distributed training. Default: None. This parameter can only be specified when `num_shards` is also specified.
        - **python_multiprocessing** (bool, optional) - Use Python multiprocessing to speed up computation. Default: True. Enabling this option may help when the Python logic applied to `source` is computationally heavy.
        - **max_rowsize** (int, optional) - Maximum size, in MB, of shared memory allocated for copying data between processes. Default: 6. It only takes effect when `python_multiprocessing` is True.
    .. py:method:: process()

        Processing method for the raw dataset; it is recommended to override this method in your custom Dataset.

    .. py:method:: save()

        Save the data produced by the `process` method to disk in numpy.npz format. You can also override this method in your own Dataset class.

    .. py:method:: load()

        Load data from the given (processed) path. You can also override this method in your own Dataset class.
    .. include:: mindspore.dataset.Dataset.add_sampler.rst

    .. include:: mindspore.dataset.Dataset.rst

    .. include:: mindspore.dataset.Dataset.b.rst

    .. include:: mindspore.dataset.Dataset.c.rst

    .. include:: mindspore.dataset.Dataset.d.rst

    .. include:: mindspore.dataset.Dataset.use_sampler.rst

    .. include:: mindspore.dataset.Dataset.zip.rst
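A minimal subclass sketch (illustration only, not part of this patch). `MyGraphDataset` and its in-memory toy graph are hypothetical; `process` builds the graphs from the raw files under `data_dir`, while the inherited `save`/`load` handle the numpy.npz cache:

import numpy as np
from mindspore.dataset import InMemoryGraphDataset, Graph

class MyGraphDataset(InMemoryGraphDataset):
    """Hypothetical dataset that turns raw files into in-memory graphs."""

    def __init__(self, data_dir):
        super().__init__(data_dir, column_names="graph")

    def process(self):
        # Parse raw data under self.data_dir; here a made-up toy graph.
        edges = np.array([[0, 1], [1, 2]], dtype=np.int32)
        self.graphs = [Graph(edges)]

    def __getitem__(self, index):
        graph = self.graphs[index]
        return graph.get_all_edges("0")

    def __len__(self):
        return len(self.graphs)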
@@ -131,8 +131,10 @@ mindspore.dataset
.. mscnautosummary::
    :toctree: dataset

    mindspore.dataset.ArgoverseDataset
    mindspore.dataset.Graph
    mindspore.dataset.GraphData
    mindspore.dataset.InMemoryGraphDataset

Sampler
-------
@@ -113,7 +113,10 @@ Graph
    :nosignatures:
    :template: classtemplate_inherited.rst

    mindspore.dataset.ArgoverseDataset
    mindspore.dataset.Graph
    mindspore.dataset.GraphData
    mindspore.dataset.InMemoryGraphDataset

Sampler
--------
@@ -22,7 +22,6 @@ import random
 import time
 from enum import IntEnum
 import numpy as np
-import pandas as pd
 from mindspore._c_dataengine import GraphDataClient
 from mindspore._c_dataengine import GraphDataServer
 from mindspore._c_dataengine import Tensor
@@ -138,22 +137,18 @@ class GraphData:
         if working_mode in ['local', 'client']:
             self._graph_data = GraphDataClient(self.data_format, dataset_file, num_parallel_workers, working_mode,
                                                hostname, port)
-            atexit.register(self.stop)
+            atexit.register(self._stop)

         if working_mode == 'server':
             self._graph_data = GraphDataServer(
                 self.data_format, dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
-            atexit.register(self.stop)
+            atexit.register(self._stop)
             try:
                 while self._graph_data.is_stopped() is not True:
                     time.sleep(1)
             except KeyboardInterrupt:
                 raise Exception("Graph data server receives KeyboardInterrupt.")

-    def stop(self):
-        """Stop GraphDataClient or GraphDataServer."""
-        self._graph_data.stop()
-
     @check_gnn_get_all_nodes
     def get_all_nodes(self, node_type):
         """
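The rename registers a private `_stop` with `atexit`, so interpreter shutdown still releases the client/server while `stop` disappears from the public surface. A standalone sketch of the same pattern (illustrative, not MindSpore code):

import atexit

class Service:
    def __init__(self):
        self._running = True
        # Ensure cleanup runs at interpreter exit even if the user forgets.
        atexit.register(self._stop)

    def _stop(self):
        if self._running:
            self._running = False
            print("service stopped")

svc = Service()  # prints "service stopped" when the interpreter exits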
@@ -530,23 +525,30 @@ class GraphData:
         return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
                                             default_node).as_array()

+    def _stop(self):
+        """Stop GraphDataClient or GraphDataServer."""
+        self._graph_data.stop()
+

 class Graph(GraphData):
     """
-    A graph object for storing Graph structure and feature data.
+    A graph object for storing Graph structure and feature data, and provides capabilities such as graph sampling.

-    This class supports init graph With input numpy array data, which represent edge, node and its features.
+    This class supports init graph with input numpy array data, which represent node, edge and its features.
     If working mode is `local`, there is no need to specify input arguments like `working_mode`, `hostname`, `port`,
     `num_client`, `auto_shutdown`.

     Args:
         edges(Union[list, numpy.ndarray]): edges of graph in COO format with shape [2, num_edges].
-        node_feat(dict, optional): feature of nodes, key is feature type, value should be numpy.array with shape
-            [num_nodes, num_node_features], feature type should be string, like 'weight' etc.
-        edge_feat(dict, optional): feature of edges, key is feature type, value should be numpy.array with shape
-            [num_edges, num_edge_features], feature type should be string, like 'weight' etc.
-        graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, key is
-            feature type, value should be numpy.array.
+        node_feat(dict, optional): feature of nodes, input data format should be dict, key is feature type, which is
+            represented with string like 'weight' etc, value should be numpy.array with shape
+            [num_nodes, num_node_features].
+        edge_feat(dict, optional): feature of edges, input data format should be dict, key is feature type, which is
+            represented with string like 'weight' etc, value should be numpy.array with shape
+            [num_edges, num_edge_features].
+        graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, input
+            data format should be dict, key is feature type, which is represented with string, value should be
+            numpy.array, its shape is not restricted.
         node_type(Union[list, numpy.ndarray], optional): type of nodes, each element should be string which represent
             type of corresponding node. If not provided, default type for each node is '0'.
         edge_type(Union[list, numpy.ndarray], optional): type of edges, each element should be string which represent
@@ -630,23 +632,19 @@ class Graph(GraphData):
             self._graph_data = GraphDataClient(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat,
                                                node_type, edge_type, num_parallel_workers, working_mode, hostname,
                                                port)
-            atexit.register(self.stop)
+            atexit.register(self._stop)

         if working_mode == 'server':
             self._graph_data = GraphDataServer(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat,
                                                node_type, edge_type, num_parallel_workers, hostname, port, num_client,
                                                auto_shutdown)
-            atexit.register(self.stop)
+            atexit.register(self._stop)
             try:
                 while self._graph_data.is_stopped() is not True:
                     time.sleep(1)
             except KeyboardInterrupt:
                 raise Exception("Graph data server receives KeyboardInterrupt.")

-    def stop(self):
-        """Stop GraphDataClient or GraphDataServer."""
-        self._graph_data.stop()
-
     @check_gnn_get_all_nodes
     def get_all_nodes(self, node_type):
         """
@@ -679,7 +677,8 @@ class Graph(GraphData):
         Get all edges in the graph.

         Args:
-            edge_type (int): Specify the type of edge.
+            edge_type (str): Specify the type of edge, default edge_type is '0' when init graph without specifying
+                edge_type.

         Returns:
             numpy.ndarray, array of edges.
@@ -1037,7 +1036,7 @@ class Graph(GraphData):

         Returns:
             dict, meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
-                node_feature_type and edge_feature_type.
+                node_feature_type, edge_feature_type and graph_feature_type.
         """
         if self._working_mode == 'server':
             raise Exception("This method is not supported when working mode is server.")
@@ -1223,11 +1222,33 @@ class _UsersDatasetTemplate:

 class InMemoryGraphDataset(GeneratorDataset):
     """
-    The basic Dataset for loading graph into memory.
-    Recommended to inherit this class, and implement your own method like 'process', 'save' and 'load'.
+    Basic Dataset for loading graph into memory.
+
+    Recommended to implement your own dataset by inheriting this class, and implement your own methods like
+    'process', 'save' and 'load'.
+
+    Args:
+        data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in
+            `process` method.
+        save_dir (str): relative directory for saving processed dataset, this directory is under `data_dir`.
+        column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset,
+            num of column name should be equal to num of item in return data when implement method like `__getitem__`.
+        num_samples (int, optional): The number of samples to be included in the dataset (default=None, all samples).
+        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
+        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
+            required (default=None, expected order behavior shown in the table).
+        num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
+            Random accessible input is required. When this argument is specified, `num_samples` reflects the max
+            sample number per shard.
+        shard_id (int, optional): The shard ID within `num_shards` (default=None). This argument must be specified
+            only when `num_shards` is also specified. Random accessible input is required.
+        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
+            option could be beneficial if the Python operation is computationally heavy (default=True).
+        max_rowsize (int, optional): Maximum size of row in MB that is used for shared memory allocation to copy
+            data between processes. This is only used if `python_multiprocessing` is set to True (default 6 MB).
     """

-    def __init__(self, data_dir, column_names="graph", save_dir="./processed", num_parallel_workers=1,
+    def __init__(self, data_dir, save_dir="./processed", column_names="graph", num_samples=None, num_parallel_workers=1,
                  shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6):
         self.graphs = []
         self.data_dir = data_dir
@@ -1244,25 +1265,25 @@ class InMemoryGraphDataset(GeneratorDataset):
             setattr(source, k, v)
         for k, v in self.__class__.__dict__.items():
             setattr(source.__class__, k, getattr(self.__class__, k))
-        super().__init__(source, column_names=column_names, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
-                         num_shards=num_shards, shard_id=shard_id, python_multiprocessing=python_multiprocessing,
-                         max_rowsize=max_rowsize)
+        super().__init__(source, column_names=column_names, num_samples=num_samples,
+                         num_parallel_workers=num_parallel_workers, shuffle=shuffle, num_shards=num_shards,
+                         shard_id=shard_id, python_multiprocessing=python_multiprocessing, max_rowsize=max_rowsize)

     def process(self):
         """
-        Override this method in your our dataset class.
+        Process method based on origin dataset, override this method in your own dataset class.
         """
         raise NotImplementedError("'process' method should be implemented in your own logic.")

     def save(self):
         """
-        Override this method in your our dataset class.
+        Save processed data into disk in numpy.npz format, you can also override this method in your dataset class.
         """
         save_graphs(self.processed_path, self.graphs)

     def load(self):
         """
-        Override this method in your our dataset class.
+        Load data from given (processed) path, you can also override this method in your dataset class.
         """
         self.graphs = load_graphs(self.processed_path, num_parallel_workers=1)
@@ -1280,14 +1301,32 @@ class InMemoryGraphDataset(GeneratorDataset):
 class ArgoverseDataset(InMemoryGraphDataset):
     """
     Load argoverse dataset and create graph.
+
+    Here argoverse dataset is a public dataset for autonomous driving. The current implementation of
+    `ArgoverseDataset` is mainly for loading the Motion Forecasting Dataset in the argoverse dataset; visit the
+    official website for more detail: https://www.argoverse.org/av1.html#download-link.
+
+    Args:
+        data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in
+            `process` method.
+        column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset,
+            num of column name should be equal to num of item in return data when implement method like `__getitem__`.
+        num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
+        shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is
+            required (default=None, expected order behavior shown in the table).
+        python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker processes. This
+            option could be beneficial if the Python operation is computationally heavy (default=True).
+        perf_mode (bool, optional): mode for obtaining higher performance when iterating over the created dataset
+            (will call `__getitem__` method in this process). Default True, will save all the data in graph
+            (like edge index, node feature and graph feature) into graph feature.
     """

-    def __init__(self, data_dir, column_names="graph", shuffle=None, num_parallel_workers=1,
+    def __init__(self, data_dir, column_names="graph", num_parallel_workers=1, shuffle=None,
                  python_multiprocessing=True, perf_mode=True):
+        # For high performance, here we store edge_index into graph_feature directly
         self.perf_mode = perf_mode
-        super().__init__(data_dir, column_names, shuffle=shuffle, num_parallel_workers=num_parallel_workers,
-                         python_multiprocessing=python_multiprocessing)
+        super().__init__(data_dir=data_dir, column_names=column_names, shuffle=shuffle,
+                         num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing)

     def __getitem__(self, index):
         graph = self.graphs[index]
@@ -1312,6 +1351,10 @@ class ArgoverseDataset(InMemoryGraphDataset):
         """
         process method mainly refers to: https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py
         """
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ImportError("Import pandas failed, recommend to install pandas with pip.")

         def get_edge_full_connection(node_num, start_index=0):
             """