mindspore/docs/api/api_python/dataset/mindspore.dataset.GraphData...

358 lines
13 KiB
ReStructuredText
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

mindspore.dataset.GraphData
===========================
.. py:class:: mindspore.dataset.GraphData(dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1, auto_shutdown=True)
从共享文件和数据库中读取用于GNN训练的图数据集。
**参数:**
- **dataset_file** (str) - 数据集文件路径。
- **num_parallel_workers** (int, 可选) - 读取数据的工作线程数默认为None
- **working_mode** (str, 可选) - 设置工作模式,目前支持'local'/'client'/'server'(默认为'local')。
- **local**:用于非分布式训练场景。
- **client**:用于分布式训练场景。客户端不加载数据,而是从服务器获取数据。
- **server**:用于分布式训练场景。服务器加载数据并可供客户端使用。
- **hostname** (str, 可选) - 图数据集服务器的主机名。该参数仅在工作模式设置为 'client' 或 'server' 时有效(默认为'127.0.0.1')。
- **port** (int, 可选) - 图数据服务器的端口取值范围为1024-65535。此参数仅当工作模式设置为 'client' 或 'server' 默认为50051时有效。
- **num_client** (int, 可选) - 期望连接到服务器的最大客户端数。服务器将根据该参数分配资源。该参数仅在工作模式设置为 'server' 时有效默认为1
- **auto_shutdown** (bool, 可选) - 当工作模式设置为 'server' 时有效。当连接的客户端数量达到 `num_client` 且没有客户端正在连接时服务器将自动退出默认为True
**样例:**
>>> graph_dataset_dir = "/path/to/graph_dataset_file"
>>> graph_dataset = ds.GraphData(dataset_file=graph_dataset_dir, num_parallel_workers=2)
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[1])
.. py:method:: get_all_edges(edge_type)
获取图的所有边。
**参数:**
- **edge_type** (int) - 指定边的类型。
**返回:**
numpy.ndarray包含边的数组。
**样例:**
>>> edges = graph_dataset.get_all_edges(edge_type=0)
**异常:**
**TypeError**:参数 `edge_type` 的类型不为整型。
.. py:method:: get_all_neighbors(node_list, neighbor_type, output_format=<OutputFormat.NORMAL: 0。
获取 `node_list` 所有节点的相邻节点,以 `neighbor_type` 类型返回。格式的定义参见以下示例1表示两个节点之间连接0表示不连接。
.. list-table:: 邻接矩阵
:widths: 20 20 20 20 20
:header-rows: 1
* -
- 0
- 1
- 2
- 3
* - 0
- 0
- 1
- 0
- 0
* - 1
- 0
- 0
- 1
- 0
* - 2
- 1
- 0
- 0
- 1
* - 3
- 1
- 0
- 0
- 0
.. list-table:: 普通格式
:widths: 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 3
* - dst_0
- 1
- 2
- 0
- 1
* - dst_1
- -1
- -1
- 3
- -1
.. list-table:: COO格式
:widths: 20 20 20 20 20 20
:header-rows: 1
* - src
- 0
- 1
- 2
- 2
- 3
* - dst
- 1
- 2
- 0
- 3
- 1
.. list-table:: CSR格式
:widths: 40 20 20 20 20 20
:header-rows: 1
* - offsetTable
- 0
- 1
- 2
- 4
-
* - dstTable
- 1
- 2
- 0
- 3
- 1
**参数:**
- **node_list** (Union[list, numpy.ndarray]) - 给定的节点列表。
- **neighbor_type** (int) - 指定相邻节点的类型。
- **output_format** (OutputFormat, 可选) - 输出存储格式默认为mindspore.dataset.engine.OutputFormat.NORMAL取值范围[OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]。
**返回:**
对于普通格式或COO格式将返回numpy.ndarray类型的数组表示相邻节点。如果指定了CSR格式将返回两个numpy.ndarray数组第一个表示偏移表第二个表示相邻节点。
**样例:**
>>> from mindspore.dataset.engine import OutputFormat
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
... output_format=OutputFormat.COO)
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
... output_format=OutputFormat.CSR)
**异常:**
- **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `neighbor_type` 的类型不为整型。
.. py:method:: get_all_nodes(node_type)
获取图中的所有节点。
**参数:**
- **node_type** (int) - 指定节点的类型。
**返回:**
numpy.ndarray包含节点的数组。
**样例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
**异常:**
**TypeError**:参数 `node_type` 的类型不为整型。
.. py:method:: get_edges_from_nodes(node_list)
从节点获取边。
**参数:**
- **node_list** (Union[list[tuple], numpy.ndarray]) - 含一个或多个图节点ID对的列表。
**返回:**
numpy.ndarray含一个或多个边ID的数组。
**示例:**
>>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
**异常:**
**TypeError**:参数 `edge_list` 的类型不为列表或numpy.ndarray。
.. py:method:: get_edge_feature(edge_list, feature_types)
获取 `edge_list` 列表中边的特征,以 `feature_types` 类型返回。
**参数:**
- **edge_list** (Union[list, numpy.ndarray]) - 包含边的列表。
- **feature_types** (Union[list, numpy.ndarray]) - 包含给定特征类型的列表。
**返回:**
numpy.ndarray包含特征的数组。
**样例:**
>>> edges = graph_dataset.get_all_edges(edge_type=0)
>>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=[1])
**异常:**
- **TypeError** - 参数 `edge_list` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。
.. py:method:: get_neg_sampled_neighbors(node_list, neg_neighbor_num, neg_neighbor_type)
获取 `node_list` 列表中节所有点的负样本相邻节点,以 `neg_neighbor_type` 类型返回。
**参数:**
- **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。
- **neg_neighbor_num** (int) - 采样的相邻节点数量。
- **neg_neighbor_type** (int) - 指定负样本相邻节点的类型。
**返回:**
numpy.ndarray包含相邻节点的数组。
**样例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
... neg_neighbor_type=2)
**异常:**
- **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `neg_neighbor_num` 的类型不为整型。
- **TypeError** - 参数 `neg_neighbor_type` 的类型不为整型。
.. py:method:: get_nodes_from_edges(edge_list)
从图中的边获取节点。
**参数:**
- **edge_list** (Union[list, numpy.ndarray]) - 包含边的列表。
**返回:**
numpy.ndarray包含节点的数组。
**异常:**
**TypeError** 参数 `edge_list` 不为列表或ndarray。
.. py:method:: get_node_feature(node_list, feature_types)
获取 `node_list` 中节点的特征,以 `feature_types` 类型返回。
**参数:**
- **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。
- **feature_types** (Union[list, numpy.ndarray]) - 指定特征的类型。
**返回:**
numpy.ndarray包含特征的数组。
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=[2, 3])
**异常:**
- **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `feature_types` 的类型不为列表或numpy.ndarray。
.. py:method:: get_sampled_neighbors(node_list, neighbor_nums, neighbor_types, strategy=<SamplingStrategy.RANDOM: 0>)
获取已采样相邻节点信息。此API支持多跳相邻节点采样。即将上一次采样结果作为下一跳采样的输入最多允许6跳。采样结果平铺成列表格式为[input node, 1-hop sampling result, 2-hop samling result ...]
**参数:**
- **node_list** (Union[list, numpy.ndarray]) - 包含节点的列表。
- **neighbor_nums** (Union[list, numpy.ndarray]) - 每跳采样的相邻节点数。
- **neighbor_types** (Union[list, numpy.ndarray]) - 每跳采样的相邻节点类型。
- **strategy** (SamplingStrategy, 可选) - 采样策略默认为mindspore.dataset.engine.SamplingStrategy.RANDOM。取值范围[SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT]。
- **SamplingStrategy.RANDOM**:随机抽样,带放回采样。
- **SamplingStrategy.EDGE_WEIGHT**:以边缘权重为概率进行采样。
**返回:**
numpy.ndarray包含相邻节点的数组。
*样例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
... neighbor_types=[2, 1])
**异常:**
- **TypeError** - 参数 `node_list` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `neighbor_nums` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `neighbor_types` 的类型不为列表或numpy.ndarray。
.. py:method:: graph_info()
获取图的元信息,包括节点数、节点类型、节点特征信息、边数、边类型、边特征信息。
**返回:**
dict图的元信息。键为 `node_num` 、 `node_type` 、 `node_feature_type` 、 `edge_num` 、 `edge_type` 、和 `edge_feature_type` 。
.. py:method:: random_walk(target_nodes, meta_path, step_home_param=1.0, step_away_param=1.0, default_node=-1)
在节点中的随机游走。
**参数:**
- **target_nodes** (list[int]) - 随机游走中的起始节点列表。
- **meta_path** (list[int]) - 每个步长的节点类型。
- **step_home_param** (float, 可选) - 返回node2vec算法中的超参默认为1.0)。
- **step_away_param** (float, 可选) - node2vec算法中的in和out超参默认为1.0)。
- **default_node** (int, 可选) - 如果找不到更多相邻节点,则为默认节点(默认值为-1表示不给定节点
**返回:**
numpy.ndarray包含节点的数组。
**示例:**
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> walks = graph_dataset.random_walk(target_nodes=nodes, meta_path=[2, 1, 2])
**异常:**
- **TypeError** - 参数 `target_nodes` 的类型不为列表或numpy.ndarray。
- **TypeError** - 参数 `meta_path` 的类型不为列表或numpy.ndarray。