forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6V3] add new data operator FakeImage
This commit is contained in:
parent
94a7298690
commit
d2f22a8726
|
@ -97,6 +97,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/random_node.h"
|
||||
|
@ -1070,6 +1071,30 @@ EMnistDataset::EMnistDataset(const std::vector<char> &dataset_dir, const std::ve
|
|||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler ? sampler->Parse() : nullptr;
|
||||
auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FakeImageDataset::FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
auto sampler_obj = sampler.get().Parse();
|
||||
auto ds = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed, sampler_obj, cache);
|
||||
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
|
||||
}
|
||||
|
||||
FlickrDataset::FlickrDataset(const std::vector<char> &dataset_dir, const std::vector<char> &annotation_file,
|
||||
bool decode, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache) {
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
#include "minddata/dataset/engine/ir/datasetops/source/csv_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/div2k_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/emnist_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/flickr_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
@ -164,6 +165,18 @@ PYBIND_REGISTER(EMnistNode, 2, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(FakeImageNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FakeImageNode, DatasetNode, std::shared_ptr<FakeImageNode>>(
|
||||
*m, "FakeImageNode", "to create a FakeImageNode")
|
||||
.def(py::init([](int32_t num_images, const std::vector<int32_t> image_size, int32_t num_classes,
|
||||
int32_t base_seed, py::handle sampler) {
|
||||
auto fake_image = std::make_shared<FakeImageNode>(num_images, image_size, num_classes, base_seed,
|
||||
toSamplerObj(sampler), nullptr);
|
||||
THROW_IF_ERROR(fake_image->ValidateParams());
|
||||
return fake_image;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
FlickrNode, 2, ([](const py::module *m) {
|
||||
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
|
||||
|
|
|
@ -23,6 +23,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
|||
flickr_op.cc
|
||||
qmnist_op.cc
|
||||
emnist_op.cc
|
||||
fake_image_op.cc
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/datasetops/source/fake_image_op.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
FakeImageOp::FakeImageOp(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, int32_t num_workers, int32_t op_connector_size,
|
||||
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
|
||||
: MappableLeafOp(num_workers, op_connector_size, std::move(sampler)),
|
||||
num_images_(num_images),
|
||||
image_size_(image_size),
|
||||
num_classes_(num_classes),
|
||||
base_seed_(base_seed),
|
||||
image_tensor_({}),
|
||||
data_schema_(std::move(data_schema)) {}
|
||||
|
||||
// Load 1 TensorRow (image, label) using 1 trow.
|
||||
Status FakeImageOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
|
||||
RETURN_UNEXPECTED_IF_NULL(trow);
|
||||
std::shared_ptr<Tensor> image, label;
|
||||
|
||||
auto images_buf = std::make_unique<double[]>(image_total_size_);
|
||||
auto pixels = &images_buf[0];
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(access_mutex_);
|
||||
if (image_tensor_[row_id] == nullptr) {
|
||||
rand_gen_.seed(base_seed_ + row_id); // set seed for random generator.
|
||||
std::normal_distribution<double> distribution(0.0, 1.0);
|
||||
for (int i = 0; i < image_total_size_; ++i) {
|
||||
pixels[i] = distribution(rand_gen_); // generate the Gaussian distribution pixel.
|
||||
if (pixels[i] < 0) {
|
||||
pixels[i] = 0;
|
||||
}
|
||||
if (pixels[i] > 255) {
|
||||
pixels[i] = 255;
|
||||
}
|
||||
}
|
||||
TensorShape img_tensor_shape = TensorShape({image_size_[0], image_size_[1], image_size_[2]});
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(img_tensor_shape, data_schema_->Column(0).Type(),
|
||||
reinterpret_cast<unsigned char *>(pixels), &image));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(image, &image_tensor_[row_id]));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(image_tensor_[row_id], &image));
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar(label_list_[row_id], &label));
|
||||
(*trow) = TensorRow(row_id, {std::move(image), std::move(label)});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A print method typically used for debugging.
|
||||
void FakeImageOp::Print(std::ostream &out, bool show_all) const {
|
||||
if (!show_all) {
|
||||
// Call the super class for displaying any common 1-liner info.
|
||||
ParallelOp::Print(out, show_all);
|
||||
} else {
|
||||
// Call the super class for displaying any common detailed info.
|
||||
ParallelOp::Print(out, show_all);
|
||||
// Then show any custom derived-internal stuff.
|
||||
out << "\nNumber of images: " << num_images_ << "\nNumber of classes: " << num_classes_
|
||||
<< "\nBase seed: " << base_seed_ << "\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Derived from RandomAccessOp.
|
||||
Status FakeImageOp::GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const {
|
||||
if (cls_ids == nullptr || !cls_ids->empty() || label_list_.empty()) {
|
||||
if (label_list_.empty()) {
|
||||
RETURN_STATUS_UNEXPECTED("No image found in dataset. Check if image was generated successfully.");
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"[Internal ERROR] Map for storing image-index pair is nullptr or has been set in other place, "
|
||||
"it must be empty before using GetClassIds.");
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < label_list_.size(); ++i) {
|
||||
(*cls_ids)[label_list_[i]].push_back(i);
|
||||
}
|
||||
for (auto &pr : (*cls_ids)) {
|
||||
pr.second.shrink_to_fit();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FakeImageOp::GetItem(int32_t index) {
|
||||
// generate one target label according to index and save it in label_list_.
|
||||
rand_gen_.seed(base_seed_ + index); // set seed for random generator.
|
||||
std::uniform_int_distribution<int32_t> dist(0, num_classes_ - 1);
|
||||
uint32_t target = dist(rand_gen_); // generate the target.
|
||||
label_list_.emplace_back(target);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FakeImageOp::PrepareData() {
|
||||
// FakeImage generate image with Gaussian distribution.
|
||||
image_total_size_ = image_size_[0] * image_size_[1] * image_size_[2];
|
||||
|
||||
for (size_t i = 0; i < num_images_; ++i) {
|
||||
RETURN_IF_NOT_OK(GetItem(i));
|
||||
}
|
||||
|
||||
label_list_.shrink_to_fit();
|
||||
num_rows_ = label_list_.size();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "Generate image failed, please check dataset API.");
|
||||
image_tensor_.clear();
|
||||
image_tensor_.resize(num_rows_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FakeImageOp::ComputeColMap() {
|
||||
// Extract the column name mapping from the schema and save it in the class.
|
||||
if (column_name_id_map_.empty()) {
|
||||
RETURN_IF_NOT_OK(data_schema_->GetColumnNameMap(&(column_name_id_map_)));
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Column name map is already set!";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,112 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
#include "minddata/dataset/engine/datasetops/parallel_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "minddata/dataset/util/path.h"
|
||||
#include "minddata/dataset/util/queue.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/util/wait_post.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class FakeImageOp : public MappableLeafOp {
|
||||
public:
|
||||
// Constructor.
|
||||
// @param int32_t num_images - Number of generated fake images.
|
||||
// @param const std::vector<int32_t> &image_size - The size of fake image.
|
||||
// @param int32_t num_classes - Number of classes in fake images.
|
||||
// @param int32_t base_seed - A base seed which is used in generating fake image randomly.
|
||||
// @param int32_t num_workers - Number of workers reading images in parallel.
|
||||
// @param int32_t op_connector_size - Connector queue size.
|
||||
// @param std::unique_ptr<DataSchema> data_schema - The schema of the fake image dataset.
|
||||
// @param td::unique_ptr<Sampler> sampler - Sampler tells FakeImageOp what to read.
|
||||
FakeImageOp(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes, int32_t base_seed,
|
||||
int32_t num_workers, int32_t op_connector_size, std::unique_ptr<DataSchema> data_schema,
|
||||
std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
// Destructor.
|
||||
~FakeImageOp() = default;
|
||||
|
||||
// Method derived from RandomAccess Op, enable Sampler to get all ids for each class.
|
||||
// @param std::map<int32_t, std::vector<int64_t>> *cls_ids - Key label, val all ids for this class.
|
||||
// @return Status The status code returned.
|
||||
Status GetClassIds(std::map<int32_t, std::vector<int64_t>> *cls_ids) const override;
|
||||
|
||||
// A print method typically used for debugging.
|
||||
// @param out - The output stream to write output to.
|
||||
// @param show_all - A bool to control if you want to show all info or just a summary.
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
// Function to count the number of samples in the FakeImage dataset.
|
||||
// @return Number of images.
|
||||
int64_t GetTotalRows() const { return num_images_; }
|
||||
|
||||
// Op name getter.
|
||||
// @return Name of the current Op.
|
||||
std::string Name() const override { return "FakeImageOp"; }
|
||||
|
||||
// Get a image from index
|
||||
// @param int32_t index - Generate one image according to index.
|
||||
Status GetItem(int32_t index);
|
||||
|
||||
private:
|
||||
// Load a tensor row according to a lable_list.
|
||||
// @param row_id_type row_id - Id for this tensor row.
|
||||
// @param TensorRow *row - Image & label read into this tensor row.
|
||||
// @return Status The status code returned.
|
||||
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
|
||||
|
||||
// Generate all labels of FakeImage dataset
|
||||
// @return Status The status code returned.
|
||||
Status PrepareData();
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return Status The status code returned.
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int32_t num_images_;
|
||||
int32_t base_seed_;
|
||||
std::vector<int> image_size_;
|
||||
int32_t num_classes_;
|
||||
|
||||
int64_t rows_per_buffer_;
|
||||
std::unique_ptr<DataSchema> data_schema_;
|
||||
|
||||
int32_t image_total_size_;
|
||||
std::vector<uint32_t> label_list_;
|
||||
std::vector<std::shared_ptr<Tensor>> image_tensor_;
|
||||
std::mt19937 rand_gen_;
|
||||
std::mutex access_mutex_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_FAKE_IMAGE_OP_H_
|
|
@ -84,6 +84,7 @@ constexpr char kCocoNode[] = "CocoDataset";
|
|||
constexpr char kCSVNode[] = "CSVDataset";
|
||||
constexpr char kDIV2KNode[] = "DIV2KDataset";
|
||||
constexpr char kEMnistNode[] = "EMnistDataset";
|
||||
constexpr char kFakeImageNode[] = "FakeImageDataset";
|
||||
constexpr char kFlickrNode[] = "FlickrDataset";
|
||||
constexpr char kGeneratorNode[] = "GeneratorDataset";
|
||||
constexpr char kImageFolderNode[] = "ImageFolderDataset";
|
||||
|
|
|
@ -13,6 +13,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
|
|||
csv_node.cc
|
||||
div2k_node.cc
|
||||
emnist_node.cc
|
||||
fake_image_node.cc
|
||||
flickr_node.cc
|
||||
image_folder_node.cc
|
||||
manifest_node.cc
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/fake_image_node.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/fake_image_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
FakeImageNode::FakeImageNode(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
num_images_(num_images),
|
||||
image_size_(image_size),
|
||||
num_classes_(num_classes),
|
||||
base_seed_(base_seed),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> FakeImageNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
auto node = std::make_shared<FakeImageNode>(num_images_, image_size_, num_classes_, base_seed_, sampler, cache_);
|
||||
return node;
|
||||
}
|
||||
|
||||
void FakeImageNode::Print(std::ostream &out) const {
|
||||
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
|
||||
}
|
||||
|
||||
Status FakeImageNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("FakeImageNode", sampler_));
|
||||
if (num_images_ <= 0) {
|
||||
std::string err_msg = "FakeImageNode: num_images must be greater than 0, but got: " + std::to_string(num_images_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
if (image_size_.size() != 3) {
|
||||
std::string err_msg =
|
||||
"FakeImageNode: image_size expecting size 3, but got image_size.size(): " + std::to_string(image_size_.size());
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
|
||||
for (auto i = 0; i < 3; i++) {
|
||||
if (image_size_[i] <= 0) {
|
||||
std::string err_msg = "FakeImageNode: image_size[" + std::to_string(i) +
|
||||
"] must be greater than 0, but got: " + std::to_string(image_size_[i]);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
}
|
||||
if (num_classes_ <= 0) {
|
||||
std::string err_msg = "FakeImageNode: num_classes must be greater than 0, but got: " + std::to_string(num_classes_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FakeImageNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
|
||||
// Do internal Schema generation.
|
||||
auto schema = std::make_unique<DataSchema>();
|
||||
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_IF_NOT_OK(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
|
||||
auto op = std::make_shared<FakeImageOp>(num_images_, image_size_, num_classes_, base_seed_, num_workers_,
|
||||
connector_que_size_, std::move(schema), std::move(sampler_rt));
|
||||
op->SetTotalRepeats(GetTotalRepeats());
|
||||
op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
|
||||
node_ops->push_back(op);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the shard id of node
|
||||
Status FakeImageNode::GetShardId(int32_t *shard_id) {
|
||||
*shard_id = sampler_->ShardId();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get Dataset size
|
||||
Status FakeImageNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) {
|
||||
if (dataset_size_ > 0) {
|
||||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64_t num_rows, sample_size;
|
||||
num_rows = num_images_;
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
*dataset_size = sample_size;
|
||||
dataset_size_ = *dataset_size;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FakeImageNode::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args, sampler_args;
|
||||
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
|
||||
args["sampler"] = sampler_args;
|
||||
args["num_parallel_workers"] = num_workers_;
|
||||
|
||||
args["num_images"] = num_images_;
|
||||
args["image_size"] = image_size_;
|
||||
args["num_classes"] = num_classes_;
|
||||
args["base_seed"] = base_seed_;
|
||||
if (cache_ != nullptr) {
|
||||
nlohmann::json cache_args;
|
||||
RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
|
||||
args["cache"] = cache_args;
|
||||
}
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,101 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FAKE_IMAGE_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FAKE_IMAGE_NODE_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class FakeImageNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
FakeImageNode(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes, int32_t base_seed,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~FakeImageNode() = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
std::string Name() const override { return "FakeImageNode"; }
|
||||
|
||||
/// \brief Print the description.
|
||||
/// \param[in] out - The output stream to write output to.
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
/// \brief Copy the node to a new object.
|
||||
/// \return A shared pointer to the new copy.
|
||||
std::shared_ptr<DatasetNode> Copy() override;
|
||||
|
||||
/// \brief A base class override function to create the required runtime dataset op objects for this class.
|
||||
/// \param[in] node_ops - A vector containing shared pointer to the Dataset Ops that this object will create.
|
||||
/// \return Status Status::OK() if build successfully.
|
||||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Parameters validation.
|
||||
/// \return Status Status::OK() if all the parameters are valid.
|
||||
Status ValidateParams() override;
|
||||
|
||||
/// \brief Get the shard id of node.
|
||||
/// \param[in] shard_id - The shard id.
|
||||
/// \return Status Status::OK() if get shard id successfully.
|
||||
Status GetShardId(int32_t *shard_id) override;
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize.
|
||||
/// \param[in] size_getter - Shared pointer to DatasetSizeGetter.
|
||||
/// \param[in] estimate - This is only supported by some of the ops and it's used to speed up the process of getting
|
||||
/// dataset size at the expense of accuracy.
|
||||
/// \param[out] dataset_size - The size of the dataset.
|
||||
/// \return Status of the function.
|
||||
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
|
||||
int64_t *dataset_size) override;
|
||||
|
||||
/// \brief Getter functions.
|
||||
const std::vector<int32_t> &ImageSize() const { return image_size_; }
|
||||
int32_t NumImages() const { return num_images_; }
|
||||
int32_t NumClasses() const { return num_classes_; }
|
||||
int32_t BaseSeed() const { return base_seed_; }
|
||||
|
||||
/// \brief Get the arguments of node.
|
||||
/// \param[out] - out_json JSON string of all attributes.
|
||||
/// \return Status of the function.
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
/// \brief Sampler getter.
|
||||
/// \return SamplerObj of the current node.
|
||||
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
|
||||
|
||||
/// \brief Sampler setter.
|
||||
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
|
||||
|
||||
private:
|
||||
int32_t num_images_;
|
||||
std::vector<int32_t> image_size_;
|
||||
int32_t num_classes_;
|
||||
int32_t base_seed_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_FAKE_IMAGE_NODE_H_
|
|
@ -1715,6 +1715,96 @@ inline std::shared_ptr<EMnistDataset> EMnist(const std::string &dataset_dir, con
|
|||
cache);
|
||||
}
|
||||
|
||||
/// \class FakeImageDataset
|
||||
/// \brief A source dataset for generating fake images.
|
||||
class FakeImageDataset : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor of FakeImageDataset.
|
||||
/// \param[in] num_images The number of images to generate, which must be positive.
|
||||
/// \param[in] image_size Size of the images, which must be a vector of three positive values.
|
||||
/// \param[in] num_classes The number of classes of the images, which must be positive.
|
||||
/// \param[in] base_seed The base seed to generate the images.
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, const std::shared_ptr<Sampler> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of FakeImageDataset.
|
||||
/// \param[in] num_images The number of images to generate, which must be positive.
|
||||
/// \param[in] image_size Size of the images, which must be a vector of three positive values.
|
||||
/// \param[in] num_classes The number of classes of the images, which must be positive.
|
||||
/// \param[in] base_seed The base seed to generate the images.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Constructor of FakeImageDataset.
|
||||
/// \param[in] num_images The number of images to generate, which must be positive.
|
||||
/// \param[in] image_size Size of the images, which must be a vector of three positive values.
|
||||
/// \param[in] num_classes The number of classes of the images, which must be positive.
|
||||
/// \param[in] base_seed The base seed to generate the images.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use.
|
||||
explicit FakeImageDataset(int32_t num_images, const std::vector<int32_t> &image_size, int32_t num_classes,
|
||||
int32_t base_seed, const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// Destructor of FakeImageDataset.
|
||||
~FakeImageDataset() = default;
|
||||
};
|
||||
|
||||
/// \brief Function to create a FakeImageDataset.
|
||||
/// \notes The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] num_images The number of images to generate, which must be positive (default = 1000).
|
||||
/// \param[in] image_size Size of the images, which must be a vector of three positive values
|
||||
/// (default = {224, 224, 3}).
|
||||
/// \param[in] num_classes The number of classes of the images, which must be positive (default = 10).
|
||||
/// \param[in] base_seed The base seed to generate the images (default = 0).
|
||||
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
|
||||
/// given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current FakeDataset.
|
||||
inline std::shared_ptr<FakeImageDataset> FakeImage(
|
||||
int32_t num_images = 1000, const std::vector<int32_t> &image_size = {224, 224, 3}, int32_t num_classes = 10,
|
||||
int32_t base_seed = 0, const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a FakeImageDataset.
|
||||
/// \notes The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] num_images The number of images to generate, which must be positive.
|
||||
/// \param[in] image_size Size of the images, which must be a vector of three positive values.
|
||||
/// \param[in] num_classes The number of classes of the images, which must be positive.
|
||||
/// \param[in] base_seed The base seed to generate the images.
|
||||
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current FakeImageDataset.
|
||||
inline std::shared_ptr<FakeImageDataset> FakeImage(int32_t num_images, const std::vector<int32_t> &image_size,
|
||||
int32_t num_classes, int32_t base_seed, const Sampler *sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
|
||||
}
|
||||
|
||||
/// \brief Function to create a FakeImageDataset.
|
||||
/// \notes The generated dataset has two columns ["image", "label"].
|
||||
/// \param[in] num_images The number of images to generate, which must be positive.
|
||||
/// \param[in] image_size Size of the images, which must be a vector of three positive values.
|
||||
/// \param[in] num_classes The number of classes of the images, which must be positive.
|
||||
/// \param[in] base_seed The base seed to generate the images.
|
||||
/// \param[in] sampler Sampler object used to choose samples from the dataset.
|
||||
/// \param[in] cache Tensor cache to use (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current FakeImageDataset.
|
||||
inline std::shared_ptr<FakeImageDataset> FakeImage(int32_t num_images, const std::vector<int32_t> &image_size,
|
||||
int32_t num_classes, int32_t base_seed,
|
||||
const std::reference_wrapper<Sampler> sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<FakeImageDataset>(num_images, image_size, num_classes, base_seed, sampler, cache);
|
||||
}
|
||||
|
||||
/// \class FlickrDataset
|
||||
/// \brief A source dataset for reading and parsing Flickr dataset.
|
||||
class FlickrDataset : public Dataset {
|
||||
|
|
|
@ -40,6 +40,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
friend class CSVDataset;
|
||||
friend class DIV2KDataset;
|
||||
friend class EMnistDataset;
|
||||
friend class FakeImageDataset;
|
||||
friend class FlickrDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class ManifestDataset;
|
||||
|
|
|
@ -66,7 +66,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
|
|||
check_bucket_batch_by_length, check_cluedataset, check_save, check_csvdataset, check_paddeddataset, \
|
||||
check_tuple_iterator, check_dict_iterator, check_schema, check_to_device_send, check_flickr_dataset, \
|
||||
check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, check_usps_dataset, check_div2k_dataset, \
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset
|
||||
check_sbu_dataset, check_qmnist_dataset, check_emnist_dataset, check_fake_image_dataset
|
||||
from ..core.config import get_callback_timeout, _init_device_info, get_enable_shared_mem, get_num_parallel_workers, \
|
||||
get_prefetch_size
|
||||
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
|
||||
|
@ -6482,6 +6482,95 @@ class EMnistDataset(MappableDataset):
|
|||
return cde.EMnistNode(self.dataset_dir, self.name, self.usage, self.sampler)
|
||||
|
||||
|
||||
class FakeImageDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for generating fake images.
|
||||
|
||||
The generated dataset has two columns :py:obj:`[image, label]`.
|
||||
The tensor of column :py:obj:`image` is of the uint8 type.
|
||||
The tensor of column :py:obj:`label` is a scalar of the uint32 type.
|
||||
|
||||
Args:
|
||||
num_images (int, optional): Number of images to generate in the dataset (default=1000).
|
||||
image_size (tuple, optional): Size of the fake image (default=(224, 224, 3)).
|
||||
num_classes (int, optional): Number of classes in the dataset (default=10).
|
||||
base_seed (int, optional): Offsets the index-based random seed used to generate each image (default=0).
|
||||
num_samples (int, optional): The number of images to be included in the dataset
|
||||
(default=None, will read all images).
|
||||
num_parallel_workers (int, optional): Number of workers to read the data
|
||||
(default=None, will use value set in the config).
|
||||
shuffle (bool, optional): Whether or not to perform shuffle on the dataset
|
||||
(default=None, expected order behavior shown in the table).
|
||||
sampler (Sampler, optional): Object used to choose samples from the
|
||||
dataset (default=None, expected order behavior shown in the table).
|
||||
num_shards (int, optional): Number of shards that the dataset will be divided into (default=None).
|
||||
When this argument is specified, `num_samples` reflects the max sample number of per shard.
|
||||
shard_id (int, optional): The shard ID within `num_shards` (default=None). This
|
||||
argument can only be specified when `num_shards` is also specified.
|
||||
cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing.
|
||||
(default=None, which means no cache is used).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If num_parallel_workers exceeds the max thread numbers.
|
||||
RuntimeError: If sampler and shuffle are specified at the same time.
|
||||
RuntimeError: If sampler and sharding are specified at the same time.
|
||||
RuntimeError: If num_shards is specified but shard_id is None.
|
||||
RuntimeError: If shard_id is specified but num_shards is None.
|
||||
ValueError: If shard_id is invalid (< 0 or >= num_shards).
|
||||
|
||||
Note:
|
||||
- This dataset can take in a sampler. 'sampler' and 'shuffle' are mutually exclusive.
|
||||
The table below shows what input arguments are allowed and their expected behavior.
|
||||
|
||||
.. list-table:: Expected Order Behavior of Using 'sampler' and 'shuffle'
|
||||
:widths: 25 25 50
|
||||
:header-rows: 1
|
||||
|
||||
* - Parameter 'sampler'
|
||||
- Parameter 'shuffle'
|
||||
- Expected Order Behavior
|
||||
* - None
|
||||
- None
|
||||
- random order
|
||||
* - None
|
||||
- True
|
||||
- random order
|
||||
* - None
|
||||
- False
|
||||
- sequential order
|
||||
* - Sampler object
|
||||
- None
|
||||
- order defined by sampler
|
||||
* - Sampler object
|
||||
- True
|
||||
- not allowed
|
||||
* - Sampler object
|
||||
- False
|
||||
- not allowed
|
||||
|
||||
Examples:
|
||||
>>> # Read 3 samples from FakeImage dataset
|
||||
>>> dataset = ds.FakeImageDataset(num_images=1000, image_size=(224,224,3),
|
||||
... num_classes=10, base_seed=0, num_samples=3)
|
||||
>>>
|
||||
>>> # Note: In FakeImage dataset, each dictionary has keys "image" and "label"
|
||||
"""
|
||||
|
||||
@check_fake_image_dataset
|
||||
def __init__(self, num_images=1000, image_size=(224, 224, 3), num_classes=10, base_seed=0, num_samples=None,
|
||||
num_parallel_workers=None, shuffle=None, sampler=None, num_shards=None, shard_id=None, cache=None):
|
||||
super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
|
||||
shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)
|
||||
|
||||
self.num_images = num_images
|
||||
self.image_size = image_size
|
||||
self.num_classes = num_classes
|
||||
self.base_seed = base_seed
|
||||
|
||||
def parse(self, children=None):
|
||||
return cde.FakeImageNode(self.num_images, self.image_size, self.num_classes, self.base_seed, self.sampler)
|
||||
|
||||
|
||||
class FlickrDataset(MappableDataset):
|
||||
"""
|
||||
A source dataset for reading and parsing Flickr8k and Flickr30k dataset.
|
||||
|
|
|
@ -1631,3 +1631,40 @@ def check_div2k_dataset(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_fake_image_dataset(method):
|
||||
"""A wrapper that wraps a parameter checker around the original Dataset(FakeImageDataset)."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
_, param_dict = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
nreq_param_int = ['num_images', 'num_classes', 'base_seed', 'num_samples',
|
||||
'num_parallel_workers', 'num_shards', 'shard_id']
|
||||
nreq_param_bool = ['shuffle']
|
||||
|
||||
validate_dataset_param_value(nreq_param_int, param_dict, int)
|
||||
validate_dataset_param_value(nreq_param_bool, param_dict, bool)
|
||||
|
||||
num_images = param_dict.get("num_images")
|
||||
check_pos_int32(num_images, "num_images")
|
||||
|
||||
image_size = param_dict.get("image_size")
|
||||
type_check(image_size, (list, tuple), "image_size")
|
||||
if len(image_size) != 3:
|
||||
raise ValueError("image_size should be a list or tuple of length 3, but got {0}".format(len(image_size)))
|
||||
for i, value in enumerate(image_size):
|
||||
check_pos_int32(value, "image_size[{0}]".format(i))
|
||||
|
||||
num_classes = param_dict.get("num_classes")
|
||||
check_pos_int32(num_classes, "num_classes")
|
||||
|
||||
check_sampler_shuffle_shard_options(param_dict)
|
||||
|
||||
cache = param_dict.get('cache')
|
||||
check_cache_option(cache)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -24,6 +24,7 @@ SET(DE_UT_SRCS
|
|||
c_api_dataset_csv_test.cc
|
||||
c_api_dataset_div2k_test.cc
|
||||
c_api_dataset_emnist_test.cc
|
||||
c_api_dataset_fake_image_test.cc
|
||||
c_api_dataset_flickr_test.cc
|
||||
c_api_dataset_iterator_test.cc
|
||||
c_api_dataset_manifest_test.cc
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
|
||||
#include "minddata/dataset/include/dataset/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::DataType;
|
||||
using mindspore::dataset::Tensor;
|
||||
using mindspore::dataset::TensorShape;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test FakeImage
|
||||
/// Expectation: get correct FakeImage dataset
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDataset) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDataset.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(50, {28, 28, 3}, 3, 0, std::make_shared<RandomSampler>(false, 10));
|
||||
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 10);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test FakeImage in pipeline mode
|
||||
/// Expectation: get correct FakeImage dataset
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDatasetWithPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDatasetWithPipeline.";
|
||||
|
||||
// Create two FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds1 = FakeImage(50, {28, 28, 3}, 3, 0, std::make_shared<RandomSampler>(false, 10));
|
||||
std::shared_ptr<Dataset> ds2 = FakeImage(50, {28, 28, 3}, 3, 0, std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Repeat operation on ds
|
||||
int32_t repeat_num = 2;
|
||||
ds1 = ds1->Repeat(repeat_num);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
repeat_num = 2;
|
||||
ds2 = ds2->Repeat(repeat_num);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create two Project operation on ds
|
||||
std::vector<std::string> column_project = {"image", "label"};
|
||||
ds1 = ds1->Project(column_project);
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
ds2 = ds2->Project(column_project);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
|
||||
// Create a Concat operation on the ds
|
||||
ds1 = ds1->Concat({ds2});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
// This will trigger the creation of the Execution Tree and launch it.
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
EXPECT_NE(row.find("image"), row.end());
|
||||
EXPECT_NE(row.find("label"), row.end());
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
auto image = row["image"];
|
||||
MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, 40);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test GetDataSize of FakeImage
|
||||
/// Expectation: get the correct size of FakeImage
|
||||
TEST_F(MindDataTestPipeline, TestGetFakeImageDatasetSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestGetFakeImageDatasetSize.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(50, {28, 28, 3}, 3, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 50);
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test DatasetGetters of FakeImage
|
||||
/// Expectation: getters of FakeImage get the correct value
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDatasetGetters) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDatasetGetters.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(50, {28, 28, 3}, 3, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 50);
|
||||
std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
|
||||
std::vector<TensorShape> shapes = ToTensorShapeVec(ds->GetOutputShapes());
|
||||
std::vector<std::string> column_names = {"image", "label"};
|
||||
int64_t num_classes = ds->GetNumClasses();
|
||||
EXPECT_EQ(types.size(), 2);
|
||||
EXPECT_EQ(types[0].ToString(), "uint8");
|
||||
EXPECT_EQ(types[1].ToString(), "uint32");
|
||||
EXPECT_EQ(shapes.size(), 2);
|
||||
EXPECT_EQ(shapes[0].ToString(), "<28,28,3>");
|
||||
EXPECT_EQ(shapes[1].ToString(), "<>");
|
||||
EXPECT_EQ(num_classes, -1);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 50);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
|
||||
EXPECT_EQ(ds->GetColumnNames(), column_names);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 50);
|
||||
EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
|
||||
EXPECT_EQ(ToTensorShapeVec(ds->GetOutputShapes()), shapes);
|
||||
EXPECT_EQ(ds->GetBatchSize(), 1);
|
||||
EXPECT_EQ(ds->GetRepeatCount(), 1);
|
||||
EXPECT_EQ(ds->GetNumClasses(), -1);
|
||||
EXPECT_EQ(ds->GetDatasetSize(), 50);
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test invalid num_images of FakeImage
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDatasetWithInvalidNumImages) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDatasetWithInvalidNumImages.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(-1, {28, 28, 3}, 3, 0, std::make_shared<RandomSampler>(false, 10));
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FakeImage input
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test invalid image_size of FakeImage
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDatasetWithInvalidImageSize) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDatasetWithInvalidImageSize.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(50, {-1, -1, -1}, 3, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FakeImageD input, {-1,-1,-1} is not a valid imagesize
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test invalid num_classes of FakeImage
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDatasetWithInvalidNumClasses) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDatasetWithInvalidNumClasses.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(50, {28, 28, 3}, -1, 0);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FakeImage input, -1 is not a valid num class
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: FakeIamge
|
||||
/// Description: test FakeImage dataset with null sampler
|
||||
/// Expectation: dataset is null
|
||||
TEST_F(MindDataTestPipeline, TestFakeImageDatasetWithNullSampler) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFakeImageDatasetWithNullSampler.";
|
||||
|
||||
// Create a FakeImage Dataset
|
||||
std::shared_ptr<Dataset> ds = FakeImage(50, {28, 28, 3}, 3, 0, nullptr);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure: invalid FakeImage input, sampler cannot be nullptr
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
|
@ -0,0 +1,303 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Test FakeImage dataset operators
|
||||
"""
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
num_images = 50
|
||||
image_size = (28, 28, 3)
|
||||
num_classes = 10
|
||||
base_seed = 0
|
||||
|
||||
|
||||
def visualize_dataset(images, labels):
|
||||
"""
|
||||
Helper function to visualize the dataset samples
|
||||
"""
|
||||
num_samples = len(images)
|
||||
for i in range(num_samples):
|
||||
plt.subplot(1, num_samples, i + 1)
|
||||
plt.imshow(images[i].squeeze(), cmap=plt.cm.gray)
|
||||
plt.title(labels[i])
|
||||
plt.show()
|
||||
|
||||
|
||||
def test_fake_image_basic():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test basic usage of FakeImage
|
||||
Expectation: the dataset is as expected
|
||||
"""
|
||||
logger.info("Test FakeImageDataset Op")
|
||||
|
||||
# case 1: test loading whole dataset
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed)
|
||||
num_iter1 = 0
|
||||
for _ in train_data.create_dict_iterator(num_epochs=1):
|
||||
num_iter1 += 1
|
||||
assert num_iter1 == num_images
|
||||
|
||||
# case 2: test num_samples
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
|
||||
num_iter2 = 0
|
||||
for _ in train_data.create_dict_iterator(num_epochs=1):
|
||||
num_iter2 += 1
|
||||
assert num_iter2 == 4
|
||||
|
||||
# case 3: test repeat
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
|
||||
train_data = train_data.repeat(5)
|
||||
num_iter3 = 0
|
||||
for _ in train_data.create_dict_iterator(num_epochs=1):
|
||||
num_iter3 += 1
|
||||
assert num_iter3 == 20
|
||||
|
||||
# case 4: test batch with drop_remainder=False, get_dataset_size, get_batch_size, get_col_names
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
|
||||
assert train_data.get_dataset_size() == 4
|
||||
assert train_data.get_batch_size() == 1
|
||||
assert train_data.get_col_names() == ['image', 'label']
|
||||
train_data = train_data.batch(batch_size=3) # drop_remainder is default to be False
|
||||
assert train_data.get_dataset_size() == 2
|
||||
assert train_data.get_batch_size() == 3
|
||||
num_iter4 = 0
|
||||
for _ in train_data.create_dict_iterator(num_epochs=1):
|
||||
num_iter4 += 1
|
||||
assert num_iter4 == 2
|
||||
|
||||
# case 5: test batch with drop_remainder=True
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=4)
|
||||
assert train_data.get_dataset_size() == 4
|
||||
assert train_data.get_batch_size() == 1
|
||||
train_data = train_data.batch(batch_size=3, drop_remainder=True) # the rest of incomplete batch will be dropped
|
||||
assert train_data.get_dataset_size() == 1
|
||||
assert train_data.get_batch_size() == 3
|
||||
num_iter5 = 0
|
||||
for _ in train_data.create_dict_iterator(num_epochs=1):
|
||||
num_iter5 += 1
|
||||
assert num_iter5 == 1
|
||||
|
||||
|
||||
def test_fake_image_pk_sampler():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test FakeImageDataset with PKSamplere
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
logger.info("Test FakeImageDataset Op with PKSampler")
|
||||
golden = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9]
|
||||
#correlation with num_classes
|
||||
sampler = ds.PKSampler(3)
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, sampler=sampler)
|
||||
num_iter = 0
|
||||
label_list = []
|
||||
for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
label_list.append(item["label"])
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(golden, label_list)
|
||||
assert num_iter == 30
|
||||
|
||||
|
||||
def test_fake_image_sequential_sampler():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test FakeImageDataset with SequentialSampler
|
||||
Expectation: the results are as expected
|
||||
"""
|
||||
logger.info("Test FakeImageDataset Op with SequentialSampler")
|
||||
num_samples = 50
|
||||
sampler = ds.SequentialSampler(num_samples=num_samples)
|
||||
train_data1 = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, sampler=sampler)
|
||||
train_data2 = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False,
|
||||
num_samples=num_samples)
|
||||
|
||||
label_list1, label_list2 = [], []
|
||||
num_iter = 0
|
||||
for item1, item2 in zip(train_data1.create_dict_iterator(num_epochs=1),
|
||||
train_data2.create_dict_iterator(num_epochs=1)):
|
||||
label_list1.append(item1["label"].asnumpy())
|
||||
label_list2.append(item2["label"].asnumpy())
|
||||
num_iter += 1
|
||||
np.testing.assert_array_equal(label_list1, label_list2)
|
||||
assert num_iter == num_samples
|
||||
|
||||
|
||||
def test_fake_image_exception():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test error cases for FakeImageDataset
|
||||
Expectation: throw exception correctly
|
||||
"""
|
||||
logger.info("Test error cases for FakeImageDataset")
|
||||
error_msg_1 = "sampler and shuffle cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_1):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, sampler=ds.PKSampler(3))
|
||||
|
||||
error_msg_2 = "sampler and sharding cannot be specified at the same time"
|
||||
with pytest.raises(RuntimeError, match=error_msg_2):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, sampler=ds.PKSampler(3), num_shards=2,
|
||||
shard_id=0)
|
||||
|
||||
error_msg_3 = "num_shards is specified and currently requires shard_id as well"
|
||||
with pytest.raises(RuntimeError, match=error_msg_3):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=10)
|
||||
|
||||
error_msg_4 = "shard_id is specified but num_shards is not"
|
||||
with pytest.raises(RuntimeError, match=error_msg_4):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shard_id=0)
|
||||
|
||||
error_msg_5 = "Input shard_id is not within the required interval"
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=5, shard_id=-1)
|
||||
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=5, shard_id=5)
|
||||
|
||||
with pytest.raises(ValueError, match=error_msg_5):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=2, shard_id=5)
|
||||
|
||||
error_msg_6 = "num_parallel_workers exceeds"
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, num_parallel_workers=0)
|
||||
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, num_parallel_workers=256)
|
||||
|
||||
with pytest.raises(ValueError, match=error_msg_6):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, shuffle=False, num_parallel_workers=-2)
|
||||
|
||||
error_msg_7 = "Argument shard_id"
|
||||
with pytest.raises(TypeError, match=error_msg_7):
|
||||
ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_shards=2, shard_id="0")
|
||||
|
||||
|
||||
def test_fake_image_visualize(plot=False):
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test FakeImageDataset visualized results
|
||||
Expectation: get correct dataset of FakeImage
|
||||
"""
|
||||
logger.info("Test FakeImageDataset visualization")
|
||||
|
||||
train_data = ds.FakeImageDataset(num_images, image_size, num_classes, base_seed, num_samples=10, shuffle=False)
|
||||
num_iter = 0
|
||||
image_list, label_list = [], []
|
||||
for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = item["image"]
|
||||
label = item["label"]
|
||||
image_list.append(image)
|
||||
label_list.append("label {}".format(label))
|
||||
assert isinstance(image, np.ndarray)
|
||||
assert image.shape == (28, 28, 3)
|
||||
assert image.dtype == np.uint8
|
||||
assert label.dtype == np.uint32
|
||||
num_iter += 1
|
||||
assert num_iter == 10
|
||||
if plot:
|
||||
visualize_dataset(image_list, label_list)
|
||||
|
||||
|
||||
def test_fake_image_num_images():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test FakeImageDataset with num images
|
||||
Expectation: throw exception correctly or get correct dataset
|
||||
"""
|
||||
logger.info("Test FakeImageDataset num_images flag")
|
||||
|
||||
def test_config(test_num_images):
|
||||
|
||||
try:
|
||||
data = ds.FakeImageDataset(test_num_images, image_size, num_classes, base_seed, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config(num_images) == num_images
|
||||
|
||||
assert "Input num_images is not within the required interval of [1, 2147483647]." in test_config(-1)
|
||||
assert "is not of type [<class 'int'>], but got <class 'str'>." in test_config("10")
|
||||
|
||||
|
||||
def test_fake_image_image_size():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test FakeImageDataset with image size
|
||||
Expectation: throw exception correctly or get correct dataset
|
||||
"""
|
||||
logger.info("Test FakeImageDataset image_size flag")
|
||||
|
||||
def test_config(test_image_size):
|
||||
try:
|
||||
data = ds.FakeImageDataset(num_images, test_image_size, num_classes, base_seed, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config(image_size) == num_images
|
||||
|
||||
assert "Argument image_size[0] with value -1 is not of type [<class 'int'>], but got <class 'str'>."\
|
||||
in test_config(("-1", 28, 3))
|
||||
assert "image_size should be a list or tuple of length 3, but got 2" in test_config((2, 2))
|
||||
assert "Input image_size[0] is not within the required interval of [1, 2147483647]." in test_config((-1, 28, 3))
|
||||
|
||||
|
||||
def test_fake_image_num_classes():
|
||||
"""
|
||||
Feature: FakeImage
|
||||
Description: test FakeImageDataset with num classes
|
||||
Expectation: throw exception correctly or get correct dataset
|
||||
"""
|
||||
logger.info("Test FakeImageDataset num_classes flag")
|
||||
|
||||
def test_config(test_num_classes):
|
||||
try:
|
||||
data = ds.FakeImageDataset(num_images, image_size, test_num_classes, base_seed, shuffle=False)
|
||||
num_rows = 0
|
||||
for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
num_rows += 1
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
return str(e)
|
||||
return num_rows
|
||||
|
||||
assert test_config(num_classes) == num_images
|
||||
|
||||
assert "Input num_classes is not within the required interval of [1, 2147483647]." in test_config(-1)
|
||||
#should not be negative
|
||||
assert "is not of type [<class 'int'>], but got <class 'str'>." in test_config("10")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fake_image_basic()
|
||||
test_fake_image_pk_sampler()
|
||||
test_fake_image_sequential_sampler()
|
||||
test_fake_image_exception()
|
||||
test_fake_image_visualize(plot=True)
|
||||
test_fake_image_num_images()
|
||||
test_fake_image_image_size()
|
||||
test_fake_image_num_classes()
|
Loading…
Reference in New Issue