forked from mindspore-Ecosystem/mindspore
!12763 Fix C++ API testcases
From: @ezphlow Reviewed-by: Signed-off-by:
This commit is contained in:
commit
c07371f5c9
|
@ -25,6 +25,7 @@
|
||||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "minddata/dataset/engine/cache/cache_client.h"
|
#include "minddata/dataset/engine/cache/cache_client.h"
|
||||||
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
#include "minddata/dataset/engine/datasetops/cache_op.h"
|
||||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
|
@ -138,11 +138,12 @@ Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_s
|
||||||
|
|
||||||
// Helper function to validate dataset sampler parameter
|
// Helper function to validate dataset sampler parameter
|
||||||
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
|
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler) {
|
||||||
if (sampler == nullptr || sampler->ValidateParams().IsError()) {
|
if (sampler == nullptr) {
|
||||||
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
|
std::string err_msg = dataset_name + ": Sampler is not constructed correctly, sampler: nullptr";
|
||||||
MS_LOG(ERROR) << err_msg;
|
MS_LOG(ERROR) << err_msg;
|
||||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
}
|
}
|
||||||
|
RETURN_IF_NOT_OK(sampler->ValidateParams());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,6 +36,7 @@
|
||||||
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
#include "minddata/dataset/engine/datasetops/skip_op.h"
|
||||||
#include "minddata/dataset/engine/datasetops/take_op.h"
|
#include "minddata/dataset/engine/datasetops/take_op.h"
|
||||||
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
#include "minddata/dataset/engine/ir/cache/dataset_cache.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||||
#include "minddata/dataset/include/datasets.h"
|
#include "minddata/dataset/include/datasets.h"
|
||||||
#include "minddata/dataset/util/path.h"
|
#include "minddata/dataset/util/path.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
|
@ -45,9 +45,7 @@ class TensorRow;
|
||||||
class TensorShape;
|
class TensorShape;
|
||||||
class TreeAdapter;
|
class TreeAdapter;
|
||||||
class TreeGetters;
|
class TreeGetters;
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class Vocab;
|
class Vocab;
|
||||||
#endif
|
|
||||||
|
|
||||||
class DatasetCache;
|
class DatasetCache;
|
||||||
class DatasetNode;
|
class DatasetNode;
|
||||||
|
@ -64,31 +62,23 @@ class BatchDataset;
|
||||||
class MapDataset;
|
class MapDataset;
|
||||||
class ProjectDataset;
|
class ProjectDataset;
|
||||||
class ShuffleDataset;
|
class ShuffleDataset;
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class BucketBatchByLengthDataset;
|
class BucketBatchByLengthDataset;
|
||||||
class FilterDataset;
|
class FilterDataset;
|
||||||
class CSVDataset;
|
class CSVDataset;
|
||||||
class TransferDataset;
|
class TransferDataset;
|
||||||
class ConcatDataset;
|
class ConcatDataset;
|
||||||
class RenameDataset;
|
class RenameDataset;
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class SentencePieceVocab;
|
class SentencePieceVocab;
|
||||||
enum class SentencePieceModel;
|
enum class SentencePieceModel;
|
||||||
#endif
|
|
||||||
|
|
||||||
class DSCallback;
|
class DSCallback;
|
||||||
|
|
||||||
class RepeatDataset;
|
class RepeatDataset;
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class SkipDataset;
|
class SkipDataset;
|
||||||
class TakeDataset;
|
class TakeDataset;
|
||||||
class ZipDataset;
|
class ZipDataset;
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/// \class Dataset datasets.h
|
/// \class Dataset datasets.h
|
||||||
/// \brief A base class to represent a dataset in the data pipeline.
|
/// \brief A base class to represent a dataset in the data pipeline.
|
||||||
class Dataset : public std::enable_shared_from_this<Dataset> {
|
class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
|
@ -153,7 +143,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
|
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
/// \brief Function to transfer data through a device.
|
/// \brief Function to transfer data through a device.
|
||||||
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
|
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
|
||||||
/// of data transmission per time is 256M.
|
/// of data transmission per time is 256M.
|
||||||
|
@ -186,7 +175,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
|
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
|
||||||
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
|
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
/// \brief Function to create a BatchDataset
|
/// \brief Function to create a BatchDataset
|
||||||
/// \notes Combines batch_size number of consecutive rows into batches
|
/// \notes Combines batch_size number of consecutive rows into batches
|
||||||
|
@ -198,7 +186,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
/// \return Shared pointer to the current BatchDataset
|
/// \return Shared pointer to the current BatchDataset
|
||||||
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
|
std::shared_ptr<BatchDataset> Batch(int32_t batch_size, bool drop_remainder = false);
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
/// \brief Function to create a BucketBatchByLengthDataset
|
/// \brief Function to create a BucketBatchByLengthDataset
|
||||||
/// \notes Bucket elements according to their lengths. Each bucket will be padded and batched when
|
/// \notes Bucket elements according to their lengths. Each bucket will be padded and batched when
|
||||||
/// they are full.
|
/// they are full.
|
||||||
|
@ -293,7 +280,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
const std::vector<std::string> &input_columns = {}) {
|
const std::vector<std::string> &input_columns = {}) {
|
||||||
return std::make_shared<FilterDataset>(shared_from_this(), predicate, VectorStringToChar(input_columns));
|
return std::make_shared<FilterDataset>(shared_from_this(), predicate, VectorStringToChar(input_columns));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
/// \brief Function to create a MapDataset
|
/// \brief Function to create a MapDataset
|
||||||
/// \notes Applies each operation in operations to this dataset
|
/// \notes Applies each operation in operations to this dataset
|
||||||
|
@ -396,7 +382,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
|
return std::make_shared<ProjectDataset>(shared_from_this(), VectorStringToChar(columns));
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
/// \brief Function to create a Rename Dataset
|
/// \brief Function to create a Rename Dataset
|
||||||
/// \notes Renames the columns in the input dataset
|
/// \notes Renames the columns in the input dataset
|
||||||
/// \param[in] input_columns List of the input columns to rename
|
/// \param[in] input_columns List of the input columns to rename
|
||||||
|
@ -407,7 +392,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
return std::make_shared<RenameDataset>(shared_from_this(), VectorStringToChar(input_columns),
|
return std::make_shared<RenameDataset>(shared_from_this(), VectorStringToChar(input_columns),
|
||||||
VectorStringToChar(output_columns));
|
VectorStringToChar(output_columns));
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
/// \brief Function to create a RepeatDataset
|
/// \brief Function to create a RepeatDataset
|
||||||
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
|
/// \notes Repeats this dataset count times. Repeat indefinitely if count is -1
|
||||||
/// \param[in] count Number of times the dataset should be repeated
|
/// \param[in] count Number of times the dataset should be repeated
|
||||||
|
@ -417,7 +401,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
std::shared_ptr<RepeatDataset> Repeat(int32_t count = -1) {
|
std::shared_ptr<RepeatDataset> Repeat(int32_t count = -1) {
|
||||||
return std::make_shared<RepeatDataset>(shared_from_this(), count);
|
return std::make_shared<RepeatDataset>(shared_from_this(), count);
|
||||||
}
|
}
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
/// \brief Function to create a Shuffle Dataset
|
/// \brief Function to create a Shuffle Dataset
|
||||||
/// \notes Randomly shuffles the rows of this dataset
|
/// \notes Randomly shuffles the rows of this dataset
|
||||||
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
|
/// \param[in] buffer_size The size of the buffer (must be larger than 1) for shuffling
|
||||||
|
@ -449,7 +432,6 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
||||||
all_datasets.push_back(shared_from_this());
|
all_datasets.push_back(shared_from_this());
|
||||||
return std::make_shared<ZipDataset>(all_datasets);
|
return std::make_shared<ZipDataset>(all_datasets);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
|
std::shared_ptr<DatasetNode> IRNode() { return ir_node_; }
|
||||||
|
|
||||||
|
@ -602,7 +584,6 @@ class BatchDataset : public Dataset {
|
||||||
~BatchDataset() = default;
|
~BatchDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class BucketBatchByLengthDataset : public Dataset {
|
class BucketBatchByLengthDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
BucketBatchByLengthDataset(
|
BucketBatchByLengthDataset(
|
||||||
|
@ -626,7 +607,6 @@ class FilterDataset : public Dataset {
|
||||||
const std::vector<std::vector<char>> &input_columns);
|
const std::vector<std::vector<char>> &input_columns);
|
||||||
~FilterDataset() = default;
|
~FilterDataset() = default;
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
class MapDataset : public Dataset {
|
class MapDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
|
@ -643,14 +623,12 @@ class ProjectDataset : public Dataset {
|
||||||
~ProjectDataset() = default;
|
~ProjectDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class RenameDataset : public Dataset {
|
class RenameDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
|
RenameDataset(std::shared_ptr<Dataset> input, const std::vector<std::vector<char>> &input_columns,
|
||||||
const std::vector<std::vector<char>> &output_columns);
|
const std::vector<std::vector<char>> &output_columns);
|
||||||
~RenameDataset() = default;
|
~RenameDataset() = default;
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
class RepeatDataset : public Dataset {
|
class RepeatDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
|
@ -664,7 +642,6 @@ class ShuffleDataset : public Dataset {
|
||||||
~ShuffleDataset() = default;
|
~ShuffleDataset() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class SkipDataset : public Dataset {
|
class SkipDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
SkipDataset(std::shared_ptr<Dataset> input, int32_t count);
|
SkipDataset(std::shared_ptr<Dataset> input, int32_t count);
|
||||||
|
@ -682,7 +659,6 @@ class ZipDataset : public Dataset {
|
||||||
explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &inputs);
|
explicit ZipDataset(const std::vector<std::shared_ptr<Dataset>> &inputs);
|
||||||
~ZipDataset() = default;
|
~ZipDataset() = default;
|
||||||
};
|
};
|
||||||
#endif
|
|
||||||
|
|
||||||
/// \brief Function to create a SchemaObj
|
/// \brief Function to create a SchemaObj
|
||||||
/// \param[in] schema_file Path of schema file
|
/// \param[in] schema_file Path of schema file
|
||||||
|
@ -762,7 +738,6 @@ inline std::shared_ptr<AlbumDataset> Album(const std::string &dataset_dir, const
|
||||||
VectorStringToChar(column_names), decode, sampler, cache);
|
VectorStringToChar(column_names), decode, sampler, cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
class CelebADataset : public Dataset {
|
class CelebADataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
explicit CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
explicit CelebADataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
|
||||||
|
@ -1375,7 +1350,6 @@ inline std::shared_ptr<MindDataDataset> MindData(const std::vector<std::string>
|
||||||
return std::make_shared<MindDataDataset>(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler,
|
return std::make_shared<MindDataDataset>(VectorStringToChar(dataset_files), VectorStringToChar(columns_list), sampler,
|
||||||
padded_sample, num_padded);
|
padded_sample, num_padded);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
|
|
||||||
class MnistDataset : public Dataset {
|
class MnistDataset : public Dataset {
|
||||||
public:
|
public:
|
||||||
|
@ -1427,7 +1401,6 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
|
||||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||||
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||||
}
|
}
|
||||||
#ifndef ENABLE_ANDROID
|
|
||||||
|
|
||||||
/// \brief Function to create a ConcatDataset
|
/// \brief Function to create a ConcatDataset
|
||||||
/// \notes Reload "+" operator to concat two datasets
|
/// \notes Reload "+" operator to concat two datasets
|
||||||
|
@ -1686,7 +1659,6 @@ inline std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint
|
||||||
inline std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
inline std::shared_ptr<ZipDataset> Zip(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
||||||
return std::make_shared<ZipDataset>(datasets);
|
return std::make_shared<ZipDataset>(datasets);
|
||||||
}
|
}
|
||||||
#endif
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -608,7 +608,7 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
|
||||||
} else {
|
} else {
|
||||||
// we resize here since the shape changes
|
// we resize here since the shape changes
|
||||||
// create a new bounding box with the rotate
|
// create a new bounding box with the rotate
|
||||||
cv::Rect2f bbox = cv::RotatedRect(cv::Point2f(), input_img.size(), degree).boundingRect2f();
|
cv::Rect2f bbox = cv::RotatedRect(pc, input_img.size(), degree).boundingRect2f();
|
||||||
rot.at<double>(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0;
|
rot.at<double>(0, 2) += bbox.width / 2.0 - input_img.cols / 2.0;
|
||||||
rot.at<double>(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0;
|
rot.at<double>(1, 2) += bbox.height / 2.0 - input_img.rows / 2.0;
|
||||||
// use memcpy and don't compute the new shape since openCV has a rounding problem
|
// use memcpy and don't compute the new shape since openCV has a rounding problem
|
||||||
|
|
|
@ -3671,7 +3671,7 @@ class ManifestDataset(MappableDataset):
|
||||||
decode (bool, optional): decode the images after reading (default=False).
|
decode (bool, optional): decode the images after reading (default=False).
|
||||||
num_shards (int, optional): Number of shards that the dataset will be divided
|
num_shards (int, optional): Number of shards that the dataset will be divided
|
||||||
into (default=None). When this argument is specified, `num_samples` reflects
|
into (default=None). When this argument is specified, `num_samples` reflects
|
||||||
the max sample number of per shard.
|
the max number of samples per shard.
|
||||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
||||||
argument can only be specified when `num_shards` is also specified.
|
argument can only be specified when `num_shards` is also specified.
|
||||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||||
|
|
|
@ -74,6 +74,7 @@ SET(DE_UT_SRCS
|
||||||
image_process_test.cc
|
image_process_test.cc
|
||||||
interrupt_test.cc
|
interrupt_test.cc
|
||||||
ir_callback_test.cc
|
ir_callback_test.cc
|
||||||
|
ir_sampler_test.cc
|
||||||
ir_tensor_op_fusion_pass_test.cc
|
ir_tensor_op_fusion_pass_test.cc
|
||||||
ir_tree_adapter_test.cc
|
ir_tree_adapter_test.cc
|
||||||
ir_vision_test.cc
|
ir_vision_test.cc
|
||||||
|
|
|
@ -57,10 +57,10 @@ TEST_F(MindDataTestPipeline, TestAffineAPI) {
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
// auto image = row["image"];
|
auto image = row["image"];
|
||||||
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||||
|
EXPECT_EQ(row["image"].Shape().at(0), 256);
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
// EXPECT_EQ(row["image"].Shape()[0], 256);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
EXPECT_EQ(i, 15);
|
EXPECT_EQ(i, 15);
|
||||||
|
|
|
@ -88,96 +88,6 @@ TEST_F(MindDataTestPipeline, TestImageFolderWithSamplers) {
|
||||||
iter->Stop();
|
iter->Stop();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
|
|
||||||
int64_t num_rows = 30; // dummy variable for number of rows in the dataset
|
|
||||||
std::shared_ptr<SamplerObj> sampl = std::make_shared<DistributedSamplerObj>(2, 1, false, 6, 1, -1, true);
|
|
||||||
EXPECT_NE(sampl, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt;
|
|
||||||
sampl->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);
|
|
||||||
|
|
||||||
sampl = std::make_shared<PKSamplerObj>(3, false, 0);
|
|
||||||
EXPECT_NE(sampl, nullptr);
|
|
||||||
sampl->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
|
|
||||||
|
|
||||||
sampl = std::make_shared<RandomSamplerObj>(false, 12);
|
|
||||||
EXPECT_NE(sampl, nullptr);
|
|
||||||
sampl->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
|
|
||||||
|
|
||||||
sampl = std::make_shared<SequentialSamplerObj>(0, 10);
|
|
||||||
EXPECT_NE(sampl, nullptr);
|
|
||||||
sampl->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
|
|
||||||
|
|
||||||
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
|
|
||||||
sampl = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
|
|
||||||
EXPECT_NE(sampl, nullptr);
|
|
||||||
sampl->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
|
|
||||||
|
|
||||||
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
|
|
||||||
sampl = std::make_shared<SubsetRandomSamplerObj>(indices, 11);
|
|
||||||
EXPECT_NE(sampl, nullptr);
|
|
||||||
sampl->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
|
|
||||||
|
|
||||||
// Testing chains
|
|
||||||
// Parent and child have num_samples
|
|
||||||
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
|
|
||||||
EXPECT_NE(sampl1, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt1;
|
|
||||||
sampl1->SamplerBuild(&sampler_rt1);
|
|
||||||
|
|
||||||
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SequentialSamplerObj>(0, 10);
|
|
||||||
EXPECT_NE(sampl2, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt2;
|
|
||||||
sampl2->SamplerBuild(&sampler_rt2);
|
|
||||||
sampler_rt2->AddChild(sampler_rt1);
|
|
||||||
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
|
|
||||||
|
|
||||||
// Parent doesn't have num_samples
|
|
||||||
std::shared_ptr<SamplerObj> sampl3 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
|
|
||||||
EXPECT_NE(sampl3, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt3;
|
|
||||||
sampl3->SamplerBuild(&sampler_rt3);
|
|
||||||
|
|
||||||
std::shared_ptr<SamplerObj> sampl4 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
|
|
||||||
EXPECT_NE(sampl4, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt4;
|
|
||||||
sampl4->SamplerBuild(&sampler_rt4);
|
|
||||||
sampler_rt4->AddChild(sampler_rt3);
|
|
||||||
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11);
|
|
||||||
|
|
||||||
// Child doesn't have num_samples
|
|
||||||
std::shared_ptr<SamplerObj> sampl5 = std::make_shared<RandomSamplerObj>(false, 0);
|
|
||||||
EXPECT_NE(sampl5, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt5;
|
|
||||||
sampl5->SamplerBuild(&sampler_rt5);
|
|
||||||
|
|
||||||
std::shared_ptr<SamplerObj> sampl6 = std::make_shared<PKSamplerObj>(3, false, 7);
|
|
||||||
EXPECT_NE(sampl6, nullptr);
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt6;
|
|
||||||
sampl6->SamplerBuild(&sampler_rt6);
|
|
||||||
sampler_rt6->AddChild(sampler_rt5);
|
|
||||||
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
|
|
||||||
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
|
|
||||||
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
|
|
||||||
EXPECT_FALSE(indices.empty());
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
|
||||||
sampl1->SamplerBuild(&sampler_rt);
|
|
||||||
EXPECT_NE(sampler_rt, nullptr);
|
|
||||||
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), 0);
|
|
||||||
EXPECT_TRUE(indices.empty());
|
|
||||||
std::shared_ptr<SamplerRT> sampler_rt2 = nullptr;
|
|
||||||
sampl2->SamplerBuild(&sampler_rt2);
|
|
||||||
EXPECT_NE(sampler_rt, nullptr);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(MindDataTestPipeline, TestNoSamplerSuccess1) {
|
TEST_F(MindDataTestPipeline, TestNoSamplerSuccess1) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNoSamplerSuccess1.";
|
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNoSamplerSuccess1.";
|
||||||
// Test building a dataset with no sampler provided (defaults to random sampler
|
// Test building a dataset with no sampler provided (defaults to random sampler
|
||||||
|
|
|
@ -37,8 +37,8 @@ TEST_F(MindDataTestPipeline, TestComposeSuccess) {
|
||||||
EXPECT_NE(ds, nullptr);
|
EXPECT_NE(ds, nullptr);
|
||||||
|
|
||||||
// Create objects for the tensor ops
|
// Create objects for the tensor ops
|
||||||
auto decode_op(new vision::Decode());
|
std::shared_ptr<TensorTransform> decode_op(new vision::Decode());
|
||||||
auto resize_op(new vision::Resize({777, 777}));
|
std::shared_ptr<TensorTransform> resize_op(new vision::Resize({777, 777}));
|
||||||
transforms::Compose compose({decode_op, resize_op});
|
transforms::Compose compose({decode_op, resize_op});
|
||||||
|
|
||||||
// Create a Map operation on ds
|
// Create a Map operation on ds
|
||||||
|
@ -493,10 +493,10 @@ TEST_F(MindDataTestPipeline, TestRandomChoiceSuccess) {
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
// auto image = row["image"];
|
auto image = row["image"];
|
||||||
// auto label = row["label"];
|
auto label = row["label"];
|
||||||
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||||
// MS_LOG(INFO) << "Label shape: " << label->shape();
|
MS_LOG(INFO) << "Label shape: " << label.Shape();
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -56,9 +56,9 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeRandomCropResizeJpegSuccess1) {
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
// auto image = row["image"];
|
auto image = row["image"];
|
||||||
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||||
// EXPECT_EQ(image->shape()[0] == 500 && image->shape()[1] == 500, true);
|
EXPECT_EQ(image.Shape()[0] == 500 && image.Shape()[1] == 500, true);
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,9 +98,9 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeRandomCropResizeJpegSuccess2) {
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
// auto image = row["image"];
|
auto image = row["image"];
|
||||||
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||||
// EXPECT_EQ(image->shape()[0] == 500 && image->shape()[1] == 600, true);
|
EXPECT_EQ(image.Shape()[0] == 500 && image.Shape()[1] == 600, true);
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeResizeJpegSuccess1) {
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
// auto image = row["image"];
|
// std::shared_ptr<TensorTransform> image = row["image"];
|
||||||
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
@ -180,8 +180,8 @@ TEST_F(MindDataTestPipeline, TestSoftDvppDecodeResizeJpegSuccess2) {
|
||||||
uint64_t i = 0;
|
uint64_t i = 0;
|
||||||
while (row.size() != 0) {
|
while (row.size() != 0) {
|
||||||
i++;
|
i++;
|
||||||
// auto image = row["image"];
|
auto image = row["image"];
|
||||||
// MS_LOG(INFO) << "Tensor image shape: " << image->shape();
|
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||||
iter->GetNextRow(&row);
|
iter->GetNextRow(&row);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -112,6 +112,8 @@ TEST_F(MindDataTestExecute, TestTransformInput2) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput2.";
|
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput2.";
|
||||||
// Test Execute with transform op input using API constructors, with std::shared_ptr<TensorTransform pointers,
|
// Test Execute with transform op input using API constructors, with std::shared_ptr<TensorTransform pointers,
|
||||||
// instantiated via new
|
// instantiated via new
|
||||||
|
// With this way of creating TensorTransforms, we don't need to explicitly delete the object created with the
|
||||||
|
// "new" keyword. When the shared pointer goes out of scope the object destructor will be called.
|
||||||
|
|
||||||
// Read image, construct MSTensor from dataset tensor
|
// Read image, construct MSTensor from dataset tensor
|
||||||
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
|
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
|
||||||
|
@ -171,6 +173,7 @@ TEST_F(MindDataTestExecute, TestTransformInput3) {
|
||||||
TEST_F(MindDataTestExecute, TestTransformInputSequential) {
|
TEST_F(MindDataTestExecute, TestTransformInputSequential) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInputSequential.";
|
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInputSequential.";
|
||||||
// Test Execute with transform op input using API constructors, with auto pointers;
|
// Test Execute with transform op input using API constructors, with auto pointers;
|
||||||
|
// Note that with auto and new, we have to explicitly delete the allocated object as shown below.
|
||||||
// Apply 2 transformations sequentially, including single non-vector Transform op input
|
// Apply 2 transformations sequentially, including single non-vector Transform op input
|
||||||
|
|
||||||
// Read image, construct MSTensor from dataset tensor
|
// Read image, construct MSTensor from dataset tensor
|
||||||
|
@ -207,7 +210,7 @@ TEST_F(MindDataTestExecute, TestTransformInputSequential) {
|
||||||
|
|
||||||
TEST_F(MindDataTestExecute, TestTransformDecodeResizeCenterCrop1) {
|
TEST_F(MindDataTestExecute, TestTransformDecodeResizeCenterCrop1) {
|
||||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformDecodeResizeCenterCrop1.";
|
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformDecodeResizeCenterCrop1.";
|
||||||
// Test Execute with Decode, Resize and CenterCrop transform ops input using API constructors, with auto pointers
|
// Test Execute with Decode, Resize and CenterCrop transform ops input using API constructors, with shared pointers
|
||||||
|
|
||||||
// Read image, construct MSTensor from dataset tensor
|
// Read image, construct MSTensor from dataset tensor
|
||||||
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
|
std::shared_ptr<mindspore::dataset::Tensor> de_tensor;
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "common/common.h"
|
||||||
|
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||||
|
#include "minddata/dataset/engine/ir/datasetops/source/samplers/samplers_ir.h"
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
|
||||||
|
using namespace mindspore::dataset;
|
||||||
|
using mindspore::dataset::Tensor;
|
||||||
|
|
||||||
|
class MindDataTestIrSampler : public UT::DatasetOpTesting {
|
||||||
|
protected:
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(MindDataTestIrSampler, TestCalculateNumSamples) {
|
||||||
|
int64_t num_rows = 30; // dummy variable for number of rows in the dataset
|
||||||
|
std::shared_ptr<SamplerObj> sampl = std::make_shared<DistributedSamplerObj>(2, 1, false, 6, 1, -1, true);
|
||||||
|
EXPECT_NE(sampl, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt;
|
||||||
|
sampl->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);
|
||||||
|
|
||||||
|
sampl = std::make_shared<PKSamplerObj>(3, false, 0);
|
||||||
|
EXPECT_NE(sampl, nullptr);
|
||||||
|
sampl->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
|
||||||
|
|
||||||
|
sampl = std::make_shared<RandomSamplerObj>(false, 12);
|
||||||
|
EXPECT_NE(sampl, nullptr);
|
||||||
|
sampl->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
|
||||||
|
|
||||||
|
sampl = std::make_shared<SequentialSamplerObj>(0, 10);
|
||||||
|
EXPECT_NE(sampl, nullptr);
|
||||||
|
sampl->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
|
||||||
|
|
||||||
|
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
|
||||||
|
sampl = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
|
||||||
|
EXPECT_NE(sampl, nullptr);
|
||||||
|
sampl->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
|
||||||
|
|
||||||
|
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
|
||||||
|
sampl = std::make_shared<SubsetRandomSamplerObj>(indices, 11);
|
||||||
|
EXPECT_NE(sampl, nullptr);
|
||||||
|
sampl->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
|
||||||
|
|
||||||
|
// Testing chains
|
||||||
|
// Parent and child have num_samples
|
||||||
|
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
|
||||||
|
EXPECT_NE(sampl1, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt1;
|
||||||
|
sampl1->SamplerBuild(&sampler_rt1);
|
||||||
|
|
||||||
|
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SequentialSamplerObj>(0, 10);
|
||||||
|
EXPECT_NE(sampl2, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt2;
|
||||||
|
sampl2->SamplerBuild(&sampler_rt2);
|
||||||
|
sampler_rt2->AddChild(sampler_rt1);
|
||||||
|
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
|
||||||
|
|
||||||
|
// Parent doesn't have num_samples
|
||||||
|
std::shared_ptr<SamplerObj> sampl3 = std::make_shared<WeightedRandomSamplerObj>(weights, 12);
|
||||||
|
EXPECT_NE(sampl3, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt3;
|
||||||
|
sampl3->SamplerBuild(&sampler_rt3);
|
||||||
|
|
||||||
|
std::shared_ptr<SamplerObj> sampl4 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
|
||||||
|
EXPECT_NE(sampl4, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt4;
|
||||||
|
sampl4->SamplerBuild(&sampler_rt4);
|
||||||
|
sampler_rt4->AddChild(sampler_rt3);
|
||||||
|
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11);
|
||||||
|
|
||||||
|
// Child doesn't have num_samples
|
||||||
|
std::shared_ptr<SamplerObj> sampl5 = std::make_shared<RandomSamplerObj>(false, 0);
|
||||||
|
EXPECT_NE(sampl5, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt5;
|
||||||
|
sampl5->SamplerBuild(&sampler_rt5);
|
||||||
|
|
||||||
|
std::shared_ptr<SamplerObj> sampl6 = std::make_shared<PKSamplerObj>(3, false, 7);
|
||||||
|
EXPECT_NE(sampl6, nullptr);
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt6;
|
||||||
|
sampl6->SamplerBuild(&sampler_rt6);
|
||||||
|
sampler_rt6->AddChild(sampler_rt5);
|
||||||
|
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MindDataTestIrSampler, TestSamplersMoveParameters) {
|
||||||
|
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
|
||||||
|
std::shared_ptr<SamplerObj> sampl1 = std::make_shared<SubsetRandomSamplerObj>(indices, 0);
|
||||||
|
EXPECT_FALSE(indices.empty());
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||||
|
sampl1->SamplerBuild(&sampler_rt);
|
||||||
|
EXPECT_NE(sampler_rt, nullptr);
|
||||||
|
std::shared_ptr<SamplerObj> sampl2 = std::make_shared<SubsetRandomSamplerObj>(std::move(indices), 0);
|
||||||
|
EXPECT_TRUE(indices.empty());
|
||||||
|
std::shared_ptr<SamplerRT> sampler_rt2 = nullptr;
|
||||||
|
sampl2->SamplerBuild(&sampler_rt2);
|
||||||
|
EXPECT_NE(sampler_rt, nullptr);
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_samples():
|
||||||
|
manifest_file = "../data/dataset/testManifestData/test5trainimgs.json"
|
||||||
|
num_samples = 1
|
||||||
|
# sampler = ds.DistributedSampler(num_shards=1, shard_id=0, shuffle=False, num_samples=3, offset=1)
|
||||||
|
data1 = ds.ManifestDataset(
|
||||||
|
manifest_file, num_samples=num_samples, num_shards=3, shard_id=1
|
||||||
|
)
|
||||||
|
row_count = 0
|
||||||
|
for _ in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||||
|
row_count += 1
|
||||||
|
assert row_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_samples_tf():
|
||||||
|
logger.info("test_tfrecord_read_all_dataset")
|
||||||
|
schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
|
||||||
|
files = ["../data/dataset/testTFTestAllTypes/test.data"]
|
||||||
|
# here num samples indicate the rows per shard. Total rows in file = 12
|
||||||
|
ds1 = ds.TFRecordDataset(files, schema_file, num_samples=2)
|
||||||
|
count = 0
|
||||||
|
for _ in ds1.create_tuple_iterator(num_epochs=1):
|
||||||
|
count += 1
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_num_samples_image_folder():
|
||||||
|
data_dir = "../data/dataset/testPK/data"
|
||||||
|
ds1 = ds.ImageFolderDataset(data_dir, num_samples=2, num_shards=2, shard_id=0)
|
||||||
|
count = 0
|
||||||
|
for _ in ds1.create_tuple_iterator(num_epochs=1):
|
||||||
|
count += 1
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_num_samples()
|
||||||
|
test_num_samples_tf()
|
||||||
|
test_num_samples_image_folder()
|
Loading…
Reference in New Issue