!40025 fix small bug for ArgoverseDataset
Merge pull request !40025 from ms_yan/small_fix
This commit is contained in:
commit
19f850cecb
|
@ -10,7 +10,7 @@
|
|||
|
||||
参数:
|
||||
- **data_dir** (str) - 加载数据集的目录,这里包含原始格式的数据,并将在 `process` 方法中被加载。
|
||||
- **column_names** (Union[str, list[str]],可选) - dataset包含的单个列名或多个列名组成的列表,默认值:'Graph'。当实现类似 `__getitem__` 等方法时,列名的数量应该等于该方法中返回数据的条数。
|
||||
- **column_names** (Union[str, list[str]],可选) - dataset包含的单个列名或多个列名组成的列表,默认值:'Graph'。当实现类似 `__getitem__` 等方法时,列名的数量应该等于该方法中返回数据的条数,如下述示例,建议初始化时明确它的取值如:`column_names=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]`。
|
||||
- **num_parallel_workers** (int,可选) - 指定读取数据的工作进程数/线程数(由参数 `python_multiprocessing` 决定当前为多进程模式或多线程模式),默认值:1。
|
||||
- **shuffle** (bool,可选) - 是否混洗数据集。当实现的Dataset带有可随机访问属性( `__getitem__` )时,才可以指定该参数。默认值:None。
|
||||
- **python_multiprocessing** (bool,可选) - 启用Python多进程模式加速运算,默认值:True。当传入 `source` 的Python对象的计算量很大时,开启此选项可能会有较好效果。
|
||||
|
|
|
@ -1236,6 +1236,7 @@ class _UsersDatasetTemplate:
|
|||
"""
|
||||
Internal class _ReInitTemplate.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
@ -1385,7 +1386,9 @@ class ArgoverseDataset(InMemoryGraphDataset):
|
|||
data_dir (str): directory for loading dataset, here contains origin format data and will be loaded in
|
||||
`process` method.
|
||||
column_names (Union[str, list[str]], optional): single column name or list of column names of the dataset,
|
||||
num of column name should be equal to num of item in return data when implement method like `__getitem__`.
|
||||
num of column name should be equal to num of item in return data when implement method like `__getitem__`,
|
||||
recommend to specify it with
|
||||
`column_names=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]` like the following example.
|
||||
num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required.
|
||||
(default=None, expected order behavior shown in the table).
|
||||
|
@ -1394,11 +1397,23 @@ class ArgoverseDataset(InMemoryGraphDataset):
|
|||
perf_mode(bool, optional): mode for obtaining higher performance when iterate created dataset(will call
|
||||
`__getitem__` method in this process). Default True, will save all the data in graph
|
||||
(like edge index, node feature and graph feature) into graph feature.
|
||||
|
||||
Examples:
|
||||
>>> from mindspore.dataset import ArgoverseDataset
|
||||
>>>
|
||||
>>> argoverse_dataset_dir = "/path/to/argoverse_dataset_directory"
|
||||
>>> graph_dataset = ArgoverseDataset(data_dir=argoverse_dataset_dir,
|
||||
... column_names=["edge_index", "x", "y", "cluster", "valid_len",
|
||||
... "time_step_len"])
|
||||
>>> for item in graph_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
... pass
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir, column_names="graph", num_parallel_workers=1, shuffle=None,
|
||||
python_multiprocessing=True, perf_mode=True):
|
||||
# For high performance, here we store edge_index into graph_feature directly
|
||||
if not isinstance(perf_mode, bool):
|
||||
raise TypeError("Type of 'perf_mode' should be bool, but got {}.".format(type(perf_mode)))
|
||||
self.perf_mode = perf_mode
|
||||
super().__init__(data_dir=data_dir, column_names=column_names, shuffle=shuffle,
|
||||
num_parallel_workers=num_parallel_workers, python_multiprocessing=python_multiprocessing)
|
||||
|
@ -1406,8 +1421,8 @@ class ArgoverseDataset(InMemoryGraphDataset):
|
|||
def __getitem__(self, index):
|
||||
graph = self.graphs[index]
|
||||
if self.perf_mode:
|
||||
return graph.get_graph_feature(
|
||||
feature_types=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"])
|
||||
return tuple(graph.get_graph_feature(
|
||||
feature_types=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]))
|
||||
|
||||
graph_info = graph.graph_info()
|
||||
all_nodes = graph.get_all_nodes(graph_info["node_type"][0])
|
||||
|
|
Loading…
Reference in New Issue