!40025 fix small bug for ArgoverseDataset

Merge pull request !40025 from ms_yan/small_fix
i-robot 2022-08-09 02:32:56 +00:00 committed by Gitee
commit 19f850cecb
2 changed files with 19 additions and 4 deletions


@@ -10,7 +10,7 @@
Parameters:
- **data_dir** (str) - Directory from which the dataset is loaded; it contains the data in its original format, which will be loaded in the `process` method.
- **column_names** (Union[str, list[str]], optional) - Single column name, or a list of column names, contained in the dataset. Default: 'Graph'. When a method such as `__getitem__` is implemented, the number of column names should equal the number of data items returned by that method.
- **column_names** (Union[str, list[str]], optional) - Single column name, or a list of column names, contained in the dataset. Default: 'Graph'. When a method such as `__getitem__` is implemented, the number of column names should equal the number of data items returned by that method. As in the example below, it is recommended to specify its value explicitly at initialization, e.g. `column_names=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]`.
- **num_parallel_workers** (int, optional) - Number of worker processes/threads used to read the data (whether the current mode is multiprocessing or multithreading is determined by the `python_multiprocessing` parameter). Default: 1.
- **shuffle** (bool, optional) - Whether to shuffle the dataset. This parameter can only be specified when the implemented Dataset has the random-access method `__getitem__`. Default: None.
- **python_multiprocessing** (bool, optional) - Enable Python multiprocessing to accelerate computation. Default: True. When the computation on the Python objects passed to `source` is heavy, enabling this option may give better performance.
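
For illustration, a minimal usage sketch following the recommendation above (the dataset directory is a placeholder path; the six column names correspond to the six items returned by `__getitem__`):

>>> from mindspore.dataset import ArgoverseDataset
>>>
>>> # Placeholder path to a directory containing raw-format Argoverse data.
>>> argoverse_dataset_dir = "/path/to/argoverse_dataset_directory"
>>> # One column name per item returned by __getitem__.
>>> graph_dataset = ArgoverseDataset(data_dir=argoverse_dataset_dir,
...                                  column_names=["edge_index", "x", "y", "cluster",
...                                                "valid_len", "time_step_len"])
>>> # Each row of the dict iterator is keyed by the column names above.
>>> for item in graph_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
...     pass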


@@ -1236,6 +1236,7 @@ class _UsersDatasetTemplate:
"""
Internal class _ReInitTemplate.
"""
def __init__(self):
pass
@@ -1385,7 +1386,9 @@ class ArgoverseDataset(InMemoryGraphDataset):
data_dir (str): directory for loading the dataset; it contains data in the original format, which will be
loaded in the `process` method.
column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset;
the number of column names should equal the number of items returned by a method such as `__getitem__`.
the number of column names should equal the number of items returned by a method such as `__getitem__`;
it is recommended to specify it as
`column_names=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]`, as in the following example.
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
(default=None, expected order behavior shown in the table).
@@ -1394,11 +1397,23 @@ class ArgoverseDataset(InMemoryGraphDataset):
perf_mode (bool, optional): mode for obtaining higher performance when iterating over the created dataset
(the `__getitem__` method is called in this process). Default: True, which saves all the data of the graph
(such as edge index, node feature and graph feature) into the graph feature.
Examples:
>>> from mindspore.dataset import ArgoverseDataset
>>>
>>> argoverse_dataset_dir = "/path/to/argoverse_dataset_directory"
>>> graph_dataset = ArgoverseDataset(data_dir=argoverse_dataset_dir,
... column_names=["edge_index", "x", "y", "cluster", "valid_len",
... "time_step_len"])
>>> for item in graph_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
... pass
"""
def __init__(self, data_dir, column_names="graph", num_parallel_workers=1, shuffle=None,
python_multiprocessing=True, perf_mode=True):
# For high performance, here we store edge_index into graph_feature directly
if not isinstance(perf_mode, bool):
raise TypeError("Type of 'perf_mode' should be bool, but got {}.".format(type(perf_mode)))
self.perf_mode = perf_mode
super().__init__(data_dir=data_dir, column_names=column_names, shuffle=shuffle,
num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing)
@@ -1406,8 +1421,8 @@ class ArgoverseDataset(InMemoryGraphDataset):
def __getitem__(self, index):
graph = self.graphs[index]
if self.perf_mode:
return graph.get_graph_feature(
feature_types=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"])
return tuple(graph.get_graph_feature(
feature_types=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]))
graph_info = graph.graph_info()
all_nodes = graph.get_all_nodes(graph_info["node_type"][0])
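
For context on the `__getitem__` change above: the value returned by `get_graph_feature` is now wrapped in `tuple()`, presumably so that each of the six features maps to its own output column (matching the `column_names` recommendation added in this same commit) rather than being handed back as a single sequence. A minimal sketch of that one-item-per-column contract for a random-access source consumed through `mindspore.dataset.GeneratorDataset` (the `_ToySource` class, its arrays, and the column names "x"/"y" are hypothetical, for illustration only):

import numpy as np
import mindspore.dataset as ds

class _ToySource:
    """Hypothetical random-access source whose __getitem__ returns a tuple of two items."""
    def __init__(self):
        self._x = np.arange(12, dtype=np.float32).reshape(4, 3)
        self._y = np.arange(4, dtype=np.int32)

    def __len__(self):
        return 4

    def __getitem__(self, index):
        # Returning a tuple: one entry per declared column.
        return self._x[index], self._y[index]

# Two column names, matching the two items returned by __getitem__.
dataset = ds.GeneratorDataset(_ToySource(), column_names=["x", "y"], shuffle=False)
for row in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
    print(row["x"].shape, row["y"])

The same pattern applies to ArgoverseDataset above: six recommended column names, six items per sample.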