API Docs TFRecord

This commit is contained in:
David 2023-01-30 10:26:39 -05:00
parent 4c5e1b2617
commit d837d53612
4 changed files with 17 additions and 16 deletions

View File

@ -11,14 +11,12 @@ mindspore.dataset.TFRecordDataset
支持传入JSON文件路径或 mindspore.dataset.Schema 构造的对象。默认值None。
- **columns_list** (list[str], 可选) - 指定从TFRecord文件中读取的数据列。默认值None读取所有列。
- **num_samples** (int, 可选) - 指定从数据集中读取的样本数。默认值None读取全部样本。
- 如果 `num_samples` 为None并且numRows字段由参数 `schema` 定义)不存在,则读取所有数据集。
- 如果 `compression_type` 不是 None `num_samples` 为None并且numRows字段由参数 `schema` 定义的值大于0则读取所有数据集。
- 如果 `compression_type` 为None `num_samples` 为None并且numRows字段由参数 `schema` 定义的值大于0则读取numRows条数据。
- 如果 `num_samples` 和numRows字段由参数 `schema` 定义的值都大于0此时仅有参数 `num_samples` 生效且读取给定数量的数据。
- 如果 `compression_type` 不是 None并且提供了 `num_samples` ,那么 `num_samples` 将是为每个分片从压缩文件中读取的行数。
强烈建议在 `compression_type` 为 "GZIP" 或 "ZLIB" 时提供 `num_samples` 以避免性能下降。
- 如果没有提供 `num_samples` ,则需要对同一个文件进行多次解压以获取文件大小。
`num_samples` 的处理优先级如下:
- 如果 `num_samples` 的值大于0则读取 `num_samples` 条数据。
- 否则,如果 numRows字段由参数 `schema` 定义的值大于0则读取numRows条数据。
- 否则,则读取所有数据集。
`num_samples` 或numRows字段由参数 `schema` 定义)将是为每个分片从压缩文件中读取的行数。
强烈建议在 `compression_type` 为 "GZIP" 或 "ZLIB" 时提供 `num_samples` 或numRows字段由参数 `schema` 定义)以避免为了获取文件大小对同一个文件进行多次解压而导致性能下降的问题。
- **num_parallel_workers** (int, 可选) - 指定读取数据的工作线程数。默认值None使用mindspore.dataset.config中配置的线程数。
- **shuffle** (Union[bool, Shuffle], 可选) - 每个epoch中数据混洗的模式支持传入bool类型与枚举类型进行指定。默认值mindspore.dataset.Shuffle.GLOBAL。
@ -30,10 +28,9 @@ mindspore.dataset.TFRecordDataset
- **num_shards** (int, 可选) - 指定分布式训练时将数据集进行划分的分片数。默认值None。指定此参数后`num_samples` 表示每个分片的最大样本数。
- **shard_id** (int, 可选) - 指定分布式训练时使用的分片ID号。默认值None。只有当指定了 `num_shards` 时才能指定此参数。
- **shard_equal_rows** (bool, 可选) - 分布式训练时为所有分片获取等量的数据行数。默认值False。如果 `shard_equal_rows` 为False则可能会使得每个分片的数据条目不相等从而导致分布式训练失败。因此当每个TFRecord文件的数据数量不相等时建议将此参数设置为True。注意只有当指定了 `num_shards` 时才能指定此参数。当 `compression_type` `num_samples`提供时,`shard_equal_rows` 会被视为True。
- **shard_equal_rows** (bool, 可选) - 分布式训练时为所有分片获取等量的数据行数。默认值False。如果 `shard_equal_rows` 为False则可能会使得每个分片的数据条目不相等从而导致分布式训练失败。因此当每个TFRecord文件的数据数量不相等时建议将此参数设置为True。注意只有当指定了 `num_shards` 时才能指定此参数。当 `compression_type` 不是 None`num_samples` 或numRows字段由参数 `schema` 定义)提供时,`shard_equal_rows` 会被视为True。
- **cache** (DatasetCache, 可选) - 单节点数据缓存服务,用于加快数据集处理,详情请阅读 `单节点数据缓存 <https://www.mindspore.cn/tutorials/experts/zh-CN/master/dataset/cache.html>`_ 。默认值None不使用缓存。
- **compression_type** (str, 可选) - 用于所有文件的压缩类型必须是“”“GZIP”或“ZLIB”。默认值:None即空字符串。
这将自动为所有分片获得相等的行数( `shard_equal_rows` 被认为是True),从而不能有 `num_samples` 为None的情况。
异常:
- **ValueError** - `dataset_files` 参数所指向的文件无效或不存在。

View File

@ -84,13 +84,16 @@ Status TFReaderOp::Init() {
RETURN_IF_NOT_OK(CreateSchema(dataset_files_list_[0], columns_to_load_));
}
if (compression_type_ == CompressionType::NONE && total_rows_ == 0) {
if (total_rows_ == 0) {
total_rows_ = data_schema_->NumRows();
}
if (total_rows_ < 0) {
RETURN_STATUS_UNEXPECTED(
"[Internal ERROR] num_samples or num_rows for TFRecordDataset must be greater than 0, but got: " +
std::to_string(total_rows_));
} else if (compression_type_ == CompressionType::NONE && total_rows_ == 0) {
MS_LOG(WARNING) << "Since compression_type is set, but neither num_samples nor numRows (from schema file) "
<< "is provided, performance might be degraded.";
}
// Build the index with our files such that each file corresponds to a key id.

View File

@ -5747,7 +5747,8 @@ class DATASET_API TFRecordDataset : public Dataset {
/// when num_shards is also specified. (Default = 0).
/// \param[in] shard_equal_rows Get equal rows for all shards.
/// (Default = false, number of rows of each shard may be not equal).
/// When `compression_type` and `num_samples` are provided, `shard_equal_rows` will be implied as true.
/// When `compression_type` is "GZIP" or "ZLIB", and `num_samples` or numRows (parsed from `schema` ) is
/// provided, shard_equal_rows` will be implied as true.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \param[in] compression_type Compression type to use.
/// (Default = "", which means no compression is used).
@ -5786,7 +5787,8 @@ class DATASET_API TFRecordDataset : public Dataset {
/// when num_shards is also specified. (Default = 0).
/// \param[in] shard_equal_rows Get equal rows for all shards.
/// (Default = false, number of rows of each shard may be not equal).
/// When `compression_type` and `num_samples` are provided, `shard_equal_rows` will be implied as true.
/// When `compression_type` is "GZIP" or "ZLIB", and `num_samples` or numRows (parsed from `schema` ) is
/// provided, shard_equal_rows` will be implied as true.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \param[in] compression_type Compression type to use.
/// (Default = "", which means no compression is used).
@ -5829,7 +5831,8 @@ class DATASET_API TFRecordDataset : public Dataset {
/// when num_shards is also specified. (Default = 0).
/// \param[in] shard_equal_rows Get equal rows for all shards.
/// (Default = false, number of rows of each shard may be not equal).
/// When `compression_type` and `num_samples` are provided, `shard_equal_rows` will be implied as true.
/// When `compression_type` is "GZIP" or "ZLIB", and `num_samples` or numRows (parsed from `schema` ) is
/// provided, shard_equal_rows` will be implied as true.
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
/// \param[in] compression_type Compression type to use.
/// (Default = "", which means no compression is used).

View File

@ -291,8 +291,6 @@ class TFRecordDataset(SourceDataset, UnionBaseDataset):
Default: None, which means no cache is used.
compression_type (str, optional): The type of compression used for all files, must be either '', 'GZIP', or
'ZLIB'. Default: None, as in empty string.
This will automatically get equal rows for all shards (`shard_equal_rows` considered to be True) when
`num_samples` or numRows (parsed from `schema` ) is provided.
Raises:
ValueError: If dataset_files are not valid or do not exist.