[feat][assistant][I3J6VO] add new data operator KITTI

This commit is contained in:
zx 2021-12-16 11:38:55 +08:00
parent ff49911889
commit 93617ce91e
26 changed files with 1625 additions and 1 deletions

View File

@ -98,6 +98,7 @@
#include "minddata/dataset/engine/ir/datasetops/source/imdb_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2016_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/iwslt2017_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/kitti_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/kmnist_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lfw_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/libri_tts_node.h"
@ -1375,6 +1376,27 @@ IWSLT2017Dataset::IWSLT2017Dataset(const std::vector<char> &dataset_dir, const s
ir_node_ = std::static_pointer_cast<DatasetNode>(ds);
}
// Creates the KITTI IR node from a shared-pointer sampler.
// A null sampler is forwarded to the node as a null sampler object.
KITTIDataset::KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                           const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = (sampler == nullptr) ? nullptr : sampler->Parse();
  auto kitti_node =
    std::make_shared<KITTINode>(CharToString(dataset_dir), CharToString(usage), decode, parsed_sampler, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(kitti_node);
}
// Creates the KITTI IR node from a raw sampler pointer.
// A null pointer is forwarded to the node as a null sampler object.
KITTIDataset::KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                           const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = (sampler == nullptr) ? nullptr : sampler->Parse();
  auto kitti_node =
    std::make_shared<KITTINode>(CharToString(dataset_dir), CharToString(usage), decode, parsed_sampler, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(kitti_node);
}
// Creates the KITTI IR node from a sampler passed by reference wrapper.
// The wrapped sampler always exists, so it is parsed unconditionally.
KITTIDataset::KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
                           const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
  auto parsed_sampler = sampler.get().Parse();
  auto kitti_node =
    std::make_shared<KITTINode>(CharToString(dataset_dir), CharToString(usage), decode, parsed_sampler, cache);
  ir_node_ = std::static_pointer_cast<DatasetNode>(kitti_node);
}
KMnistDataset::KMnistDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache) {
auto sampler_obj = sampler ? sampler->Parse() : nullptr;

View File

@ -69,6 +69,7 @@
// IR leaf nodes disabled for android
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/ir/datasetops/source/kitti_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lj_speech_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/lsun_node.h"
#include "minddata/dataset/engine/ir/datasetops/source/manifest_node.h"
@ -408,6 +409,18 @@ PYBIND_REGISTER(IWSLT2017Node, 2, ([](const py::module *m) {
}));
}));
// Exposes KITTINode to Python under the class name "KITTINode".
// The Python-side constructor takes (dataset_dir, usage, decode, sampler), converts the Python sampler handle with
// toSamplerObj, creates the node with a null cache and validates its parameters before returning it; a validation
// failure is reported through THROW_IF_ERROR.
PYBIND_REGISTER(KITTINode, 2, ([](const py::module *m) {
(void)py::class_<KITTINode, DatasetNode, std::shared_ptr<KITTINode>>(*m, "KITTINode",
"to create a KITTINode")
.def(py::init([](const std::string &dataset_dir, const std::string &usage, bool decode,
const py::handle &sampler) {
std::shared_ptr<KITTINode> kitti =
std::make_shared<KITTINode>(dataset_dir, usage, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(kitti->ValidateParams());
return kitti;
}));
}));
PYBIND_REGISTER(KMnistNode, 2, ([](const py::module *m) {
(void)py::class_<KMnistNode, DatasetNode, std::shared_ptr<KMnistNode>>(*m, "KMnistNode",
"to create a KMnistNode")

View File

@ -27,6 +27,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
imdb_op.cc
iwslt_op.cc
io_block.cc
kitti_op.cc
kmnist_op.cc
lfw_op.cc
libri_tts_op.cc

View File

@ -0,0 +1,319 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/datasetops/source/kitti_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "utils/file_utils.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace dataset {
// Positions of the fields inside one whitespace-separated line of a KITTI label file.
// Field 0 is the label name; fields 1..14 are parsed as floats by ParseAnnotationBbox.
constexpr int kLabelNameIndex = 0;
constexpr int kTruncatedIndex = 1;
constexpr int kOccludedIndex = 2;
constexpr int kAlphaIndex = 3;
constexpr int kXMinIndex = 4;
constexpr int kYMinIndex = 5;
constexpr int kXMaxIndex = 6;
constexpr int kYMaxIndex = 7;
constexpr int kFirstDimensionIndex = 8;
constexpr int kSecondDimensionIndex = 9;
constexpr int kThirdDimensionIndex = 10;
constexpr int kFirstLocationIndex = 11;
constexpr int kSecondLocationIndex = 12;
constexpr int kThirdLocationIndex = 13;
constexpr int kRotationYIndex = 14;
// Number of float values stored per annotated object (every field except the label name); ReadAnnotationToTensor
// checks each stored vector against this size.
// NOTE(review): the k*Index constants above count the label name as field 0, so each is one larger than the
// position of the same value inside the stored float vector.
constexpr int kTotalParamNums = 14;
// Sub-folders of the dataset root that hold the images and the label files.
const char kImagesFolder[] = "data_object_image_2";
const char kAnnotationsFolder[] = "data_object_label_2";
// Extensions appended to an image id to form the image and label file names.
const char kImageExtension[] = ".png";
const char kAnnotationExtension[] = ".txt";
// Width to which ParseImageIds left-pads image ids with '0' (e.g. "000012").
const int32_t kKittiFileNameLength = 6;
// Constructor.
// @param const std::string &dataset_dir - root directory of the KITTI dataset.
// @param const std::string &usage - split to read, "train" or "test".
// @param int32_t num_workers - number of parallel workers, forwarded to MappableLeafOp.
// @param int32_t queue_size - connector queue size, forwarded to MappableLeafOp.
// @param bool decode - whether images are decoded after being read.
// @param std::unique_ptr<DataSchema> data_schema - schema of the output columns; ownership is taken.
// @param const std::shared_ptr<SamplerRT> &sampler - sampler forwarded to MappableLeafOp.
KITTIOp::KITTIOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, int32_t queue_size,
                 bool decode, std::unique_ptr<DataSchema> data_schema, const std::shared_ptr<SamplerRT> &sampler)
    // `sampler` is a const reference, so std::move on it would silently copy anyway; pass it as is.
    : MappableLeafOp(num_workers, queue_size, sampler),
      // Members are listed in declaration order; row_cnt_ is given a defined value.
      decode_(decode),
      row_cnt_(0),
      folder_path_(dataset_dir),
      usage_(usage),
      data_schema_(std::move(data_schema)) {}
// Prints this op for debugging.
// @param out - The output stream to write output to.
// @param show_all - When true, prints row count, directory and decode flag after the common info; otherwise only
//     the common one-line info followed by a newline.
void KITTIOp::Print(std::ostream &out, bool show_all) const {
  // The base-class part is printed first in both modes.
  ParallelOp::Print(out, show_all);
  if (show_all) {
    out << "\nNumber of rows: " << num_rows_ << "\nKITTI directory: " << folder_path_
        << "\nDecode: " << (decode_ ? "yes" : "no") << "\n\n";
    return;
  }
  out << "\n";
}
// Loads one sample: the image and, for the "train" split, its eight annotation tensors.
// @param row_id_type row_id - index into image_ids_ of the sample to load.
// @param TensorRow *trow - output row; left untouched when usage_ is neither "train" nor "test".
// @return Status The status code returned.
Status KITTIOp::LoadTensorRow(row_id_type row_id, TensorRow *trow) {
  RETURN_UNEXPECTED_IF_NULL(trow);
  const std::string sample_id = image_ids_[row_id];
  auto resolved_dir = FileUtils::GetRealPath(folder_path_.c_str());
  if (!resolved_dir.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, KITTI Dataset dir: " << folder_path_ << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, KITTI Dataset dir: " + folder_path_ + " does not exist.");
  }
  Path root(resolved_dir.value());
  std::shared_ptr<Tensor> image_tensor;

  if (usage_ == "test") {
    // The test split only provides images.
    Path image_file = root / kImagesFolder / "testing" / "image_2" / (sample_id + kImageExtension);
    const std::string image_path = image_file.ToString();
    RETURN_IF_NOT_OK(ReadImageToTensor(image_path, data_schema_->Column(0), &image_tensor));
    trow->setId(row_id);
    trow->setPath({image_path});
    trow->push_back(std::move(image_tensor));
    return Status::OK();
  }
  if (usage_ != "train") {
    return Status::OK();
  }

  Path image_file = root / kImagesFolder / "training" / "image_2" / (sample_id + kImageExtension);
  Path label_file = root / kAnnotationsFolder / "training" / "label_2" / (sample_id + kAnnotationExtension);
  const std::string image_path = image_file.ToString();
  const std::string label_path = label_file.ToString();
  TensorRow label_tensors;
  RETURN_IF_NOT_OK(ReadImageToTensor(image_path, data_schema_->Column(0), &image_tensor));
  RETURN_IF_NOT_OK(ReadAnnotationToTensor(label_path, &label_tensors));
  trow->setId(row_id);
  // One path per output column: the image first, then the label file for each of the eight annotation columns.
  trow->setPath({image_path, label_path, label_path, label_path, label_path, label_path, label_path, label_path,
                 label_path});
  trow->push_back(std::move(image_tensor));
  trow->insert(trow->end(), label_tensors.begin(), label_tensors.end());
  return Status::OK();
}
// Builds the list of image ids for the selected split.
// The entries of the split's image folder are counted and ids "000000", "000001", ... are generated, one per
// entry; the ids are not read from the file names themselves. Does nothing when the ids are already populated.
// @return Status The status code returned.
Status KITTIOp::ParseImageIds() {
  if (!image_ids_.empty()) {
    return Status::OK();
  }
  auto folder_realpath = FileUtils::GetRealPath(folder_path_.c_str());
  if (!folder_realpath.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, KITTI Dataset dir: " << folder_path_ << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, KITTI Dataset dir: " + folder_path_ + " does not exist.");
  }
  Path path(folder_realpath.value());
  Path image_sets_file("");
  if (usage_ == "train") {
    image_sets_file = path / kImagesFolder / "training" / "image_2";
  } else if (usage_ == "test") {
    image_sets_file = path / kImagesFolder / "testing" / "image_2";
  }
  std::shared_ptr<Path::DirIterator> dirItr = Path::DirIterator::OpenDirectory(&image_sets_file);
  CHECK_FAIL_RETURN_UNEXPECTED(dirItr != nullptr, "Invalid path, failed to open KITTI image dir: " +
                                                    image_sets_file.ToString() + ", permission denied.");
  // NOTE(review): this loop relies on HasNext() advancing the iterator on every call - confirm.
  int32_t total_image_size = 0;
  while (dirItr->HasNext()) {
    total_image_size++;
  }
  // Left-pad every id with '0' up to kKittiFileNameLength characters. The pad length is computed in unsigned
  // arithmetic with an explicit floor at zero, so an id longer than the fixed width cannot wrap around.
  const size_t id_width = static_cast<size_t>(kKittiFileNameLength);
  for (int32_t i = 0; i < total_image_size; ++i) {
    const std::string id = std::to_string(i);
    const size_t pad = (id.size() < id_width) ? (id_width - id.size()) : 0;
    image_ids_.push_back(std::string(pad, '0') + id);
  }
  image_ids_.shrink_to_fit();
  num_rows_ = image_ids_.size();
  return Status::OK();
}
// Parses the label file of every image id and keeps only the ids whose label file produced at least one
// annotation. Afterwards every known label name receives a consecutive index in label_index_ iteration order.
// @return Status The status code returned; an error when no image id remains.
Status KITTIOp::ParseAnnotationIds() {
  auto resolved_dir = FileUtils::GetRealPath(folder_path_.c_str());
  if (!resolved_dir.has_value()) {
    MS_LOG(ERROR) << "Invalid file path, KITTI Dataset dir: " << folder_path_ << " does not exist.";
    RETURN_STATUS_UNEXPECTED("Invalid file path, KITTI Dataset dir: " + folder_path_ + " does not exist.");
  }
  Path root(resolved_dir.value());

  std::vector<std::string> kept_ids;
  for (const auto &image_id : image_ids_) {
    Path label_file = root / kAnnotationsFolder / "training" / "label_2" / (image_id + kAnnotationExtension);
    const std::string label_path = label_file.ToString();
    RETURN_IF_NOT_OK(ParseAnnotationBbox(label_path));
    if (annotation_map_.count(label_path) > 0) {
      kept_ids.push_back(image_id);
    }
  }
  // Only rewrite the id list when some ids were dropped.
  if (kept_ids.size() != image_ids_.size()) {
    image_ids_.assign(kept_ids.begin(), kept_ids.end());
  }

  uint32_t next_index = 0;
  for (auto it = label_index_.begin(); it != label_index_.end(); ++it) {
    it->second = next_index++;
  }

  num_rows_ = image_ids_.size();
  if (num_rows_ == 0) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, no valid data matching the dataset API KITTIDataset. Please check file path or dataset API.");
  }
  return Status::OK();
}
// Parses one KITTI label file and stores its objects in annotation_map_ under the file path.
// Each non-blank line is split on whitespace: field kLabelNameIndex is the label name and fields
// kTruncatedIndex..kRotationYIndex are converted to floats and stored in that order (kTotalParamNums values).
// Every label name seen is registered in label_index_ (its index is assigned later by ParseAnnotationIds).
// @param const std::string &path - path to the annotation (.txt) file.
// @return Status The status code returned; an error when the file is missing, cannot be opened or contains a
//     line with too few fields.
Status KITTIOp::ParseAnnotationBbox(const std::string &path) {
  CHECK_FAIL_RETURN_UNEXPECTED(Path(path).Exists(), "Invalid path, " + path + " does not exist.");
  std::ifstream in_file;
  in_file.open(path);
  if (in_file.fail()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open file: " + path);
  }
  Annotation annotation;
  std::string anno;
  while (getline(in_file, anno)) {
    std::vector<std::string> fields;
    std::string token;
    std::stringstream line(anno);
    while (line >> token) {
      fields.push_back(token);
    }
    // Blank lines (for example a trailing newline) carry no object.
    if (fields.empty()) {
      continue;
    }
    // Every index up to kRotationYIndex is read below, so reject short lines instead of reading out of bounds.
    // The stream is closed by the ifstream destructor on this early return.
    CHECK_FAIL_RETURN_UNEXPECTED(fields.size() > static_cast<size_t>(kRotationYIndex),
                                 "Invalid file, the format of the annotation file: " + path +
                                   " is not as expected, got " + std::to_string(fields.size()) +
                                   " fields in one line.");
    const std::string label_name = fields[kLabelNameIndex];
    // Stored order: truncated, occluded, alpha, xmin, ymin, xmax, ymax, 3 dimensions, 3 locations, rotation_y.
    std::vector<float> bbox_list;
    bbox_list.reserve(kTotalParamNums);
    for (int field = kTruncatedIndex; field <= kRotationYIndex; ++field) {
      bbox_list.push_back(static_cast<float>(std::atof(fields[field].c_str())));
    }
    const float xmin = bbox_list[kXMinIndex - kTruncatedIndex];
    const float ymin = bbox_list[kYMinIndex - kTruncatedIndex];
    const float xmax = bbox_list[kXMaxIndex - kTruncatedIndex];
    const float ymax = bbox_list[kYMaxIndex - kTruncatedIndex];
    // NOTE(review): label_name is a whitespace-delimited token and therefore never empty here, so this condition
    // always holds; it is kept as written to preserve which objects are accepted.
    if (label_name != "" || (xmin > 0 && ymin > 0 && xmax > xmin && ymax > ymin)) {
      annotation.emplace_back(std::make_pair(label_name, bbox_list));
      label_index_[label_name] = 0;
    }
  }
  in_file.close();
  if (annotation.size() > 0) {
    annotation_map_[path] = annotation;
  }
  return Status::OK();
}
// Prepares the sample list: image ids are always collected, annotations are parsed only for the "train" split.
// @return Status The status code returned.
Status KITTIOp::PrepareData() {
  RETURN_IF_NOT_OK(ParseImageIds());
  if (usage_ != "train") {
    return Status::OK();
  }
  RETURN_IF_NOT_OK(ParseAnnotationIds());
  return Status::OK();
}
// Reads an image file into a tensor and decodes it when decode_ is set.
// @param const std::string &path - path to the image file.
// @param const ColDescriptor &col - column descriptor (not used by this implementation).
// @param std::shared_ptr<Tensor> *tensor - output image tensor.
// @return Status The status code returned.
Status KITTIOp::ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor) {
  RETURN_UNEXPECTED_IF_NULL(tensor);
  RETURN_IF_NOT_OK(Tensor::CreateFromFile(path, tensor));
  if (!decode_) {
    return Status::OK();
  }
  Status decode_status = Decode(*tensor, tensor);
  if (decode_status.IsError()) {
    RETURN_STATUS_UNEXPECTED("Invalid data, failed to decode image: " + path);
  }
  return Status::OK();
}
// Converts the parsed annotation of one label file into eight tensors, in the order
// label, truncated, occluded, alpha, bbox, dimensions, location, rotation_y.
// With N accepted objects the tensors have shapes {N,1}, {N,1}, {N,1}, {N,1}, {N,4}, {N,3}, {N,3}, {N,1}.
// A path with no parsed annotation yields tensors with N == 0.
// @param const std::string &path - path to the annotation file (key into annotation_map_).
// @param TensorRow *row - output annotation tensors.
// @return Status The status code returned.
Status KITTIOp::ReadAnnotationToTensor(const std::string &path, TensorRow *row) {
  RETURN_UNEXPECTED_IF_NULL(row);
  // Positions inside the per-object float vector built by ParseAnnotationBbox. That vector does NOT contain the
  // label name, so these are one smaller than the k*Index token positions of the label file.
  constexpr int kTruncatedPos = 0;
  constexpr int kOccludedPos = 1;
  constexpr int kAlphaPos = 2;
  constexpr int kBboxPos = 3;        // xmin, ymin, xmax, ymax
  constexpr int kDimensionsPos = 7;  // three values
  constexpr int kLocationPos = 10;   // three values
  constexpr int kRotationYPos = 13;
  constexpr int kBboxLen = 4;
  constexpr int kTripleLen = 3;

  // Look the annotation up without operator[] so that no entry is inserted into (and no copy made from) the map.
  const Annotation no_annotation;
  auto anno_it = annotation_map_.find(path);
  const Annotation &annotation = (anno_it != annotation_map_.end()) ? anno_it->second : no_annotation;

  std::shared_ptr<Tensor> bbox, alpha, dimensions, location, occuluded, rotation_y, truncated, label;
  std::vector<float> bbox_data, alpha_data, dimensions_data, location_data, rotation_y_data, truncated_data;
  std::vector<uint32_t> occuluded_data;
  std::vector<uint32_t> label_data;
  dsize_t bbox_num = 0;
  for (const auto &item : annotation) {
    auto label_it = label_index_.find(item.first);
    if (label_it == label_index_.end()) {
      continue;
    }
    const std::vector<float> &values = item.second;
    CHECK_FAIL_RETURN_UNEXPECTED(values.size() == static_cast<size_t>(kTotalParamNums),
                                 "Invalid file, the format of the annotation file is not as expected, got " +
                                   std::to_string(values.size()) + " parameters.");
    label_data.push_back(static_cast<uint32_t>(label_it->second));
    bbox_data.insert(bbox_data.end(), values.begin() + kBboxPos, values.begin() + kBboxPos + kBboxLen);
    dimensions_data.insert(dimensions_data.end(), values.begin() + kDimensionsPos,
                           values.begin() + kDimensionsPos + kTripleLen);
    location_data.insert(location_data.end(), values.begin() + kLocationPos,
                         values.begin() + kLocationPos + kTripleLen);
    alpha_data.push_back(values[kAlphaPos]);
    truncated_data.push_back(values[kTruncatedPos]);
    occuluded_data.push_back(static_cast<uint32_t>(static_cast<int64_t>(values[kOccludedPos])));
    rotation_y_data.push_back(values[kRotationYPos]);
    bbox_num++;
  }
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(label_data, TensorShape({bbox_num, 1}), &label));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(truncated_data, TensorShape({bbox_num, 1}), &truncated));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(occuluded_data, TensorShape({bbox_num, 1}), &occuluded));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(alpha_data, TensorShape({bbox_num, 1}), &alpha));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(bbox_data, TensorShape({bbox_num, 4}), &bbox));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(dimensions_data, TensorShape({bbox_num, 3}), &dimensions));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(location_data, TensorShape({bbox_num, 3}), &location));
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(rotation_y_data, TensorShape({bbox_num, 1}), &rotation_y));
  (*row) = TensorRow({std::move(label), std::move(truncated), std::move(occuluded), std::move(alpha), std::move(bbox),
                      std::move(dimensions), std::move(location), std::move(rotation_y)});
  return Status::OK();
}
// Counts the samples of the dataset by running PrepareData() and reporting the number of image ids.
// @param int64_t *count - output number of rows.
// @return Status The status code returned.
Status KITTIOp::CountTotalRows(int64_t *count) {
  RETURN_UNEXPECTED_IF_NULL(count);
  RETURN_IF_NOT_OK(PrepareData());
  const auto total_ids = image_ids_.size();
  *count = static_cast<int64_t>(total_ids);
  return Status::OK();
}
// Fills the base-class column name map from the schema; an already populated map is left as is with a warning.
// @return Status The status code returned.
Status KITTIOp::ComputeColMap() {
  if (!column_name_id_map_.empty()) {
    MS_LOG(WARNING) << "Column name map is already set!";
    return Status::OK();
  }
  // Map every schema column name to its position.
  for (int32_t col = 0; col < data_schema_->NumColumns(); ++col) {
    column_name_id_map_[data_schema_->Column(col).Name()] = col;
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,127 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_KITTI_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_KITTI_OP_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/mappable_leaf_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
// Forward declares
template <typename T>
class Queue;
// Parsed content of one label file: one (label name, float values) pair per annotated object.
using Annotation = std::vector<std::pair<std::string, std::vector<float>>>;
/// \class KITTIOp
/// \brief A source dataset op for reading and parsing the KITTI dataset.
class KITTIOp : public MappableLeafOp {
public:
// Constructor
// @param std::string dataset_dir - root directory of KITTI.
// @param std::string usage - split of KITTI.
// @param int32_t num_workers - number of workers reading images in parallel.
// @param int32_t queue_size - connector queue size.
// @param bool decode - whether to decode images.
// @param std::unique_ptr<DataSchema> data_schema - the schema of the KITTI dataset.
// @param std::shared_ptr<SamplerRT> sampler - sampler tells KITTIOp what to read.
KITTIOp(const std::string &dataset_dir, const std::string &usage, int32_t num_workers, int32_t queue_size,
bool decode, std::unique_ptr<DataSchema> data_schema, const std::shared_ptr<SamplerRT> &sampler);
// Destructor.
~KITTIOp() = default;
// A print method typically used for debugging.
// @param out - The output stream to write output to.
// @param show_all - A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
// Function to count the number of samples in the KITTIDataset.
// @param int64_t *count - output rows number of KITTIDataset.
// @return Status The status code returned.
Status CountTotalRows(int64_t *count);
// Op name getter.
// @return Name of the current Op.
std::string Name() const override { return "KITTIOp"; }
private:
// Load a tensor row according to image id.
// @param row_id_type row_id - id for this tensor row.
// @param TensorRow *row - image & target read into this tensor row.
// @return Status The status code returned.
Status LoadTensorRow(row_id_type row_id, TensorRow *row) override;
// Load an image to Tensor.
// @param const std::string &path - path to the image file.
// @param const ColDescriptor &col - contains tensor implementation and datatype.
// @param std::shared_ptr<Tensor> *tensor - return image tensor.
// @return Status The status code returned.
Status ReadImageToTensor(const std::string &path, const ColDescriptor &col, std::shared_ptr<Tensor> *tensor);
// Load an annotation to Tensor.
// @param const std::string &path - path to the annotation file.
// @param TensorRow *row - return annotation tensors.
// @return Status The status code returned.
Status ReadAnnotationToTensor(const std::string &path, TensorRow *row);
// Build the list of image ids from the image folder of the selected split.
// @return Status The status code returned.
Status ParseImageIds();
// Read the annotation of every image id from the label folder.
// @return Status The status code returned.
Status ParseAnnotationIds();
// Function to parse annotation bbox.
// @param const std::string &path - path to the annotation (.txt) file.
// @return Status The status code returned.
Status ParseAnnotationBbox(const std::string &path);
// Private function for computing the assignment of the column name map.
// @return Status The status code returned.
Status ComputeColMap() override;
protected:
// Collect image ids and, for the "train" split, the annotations.
// @return Status The status code returned.
Status PrepareData() override;
private:
// Whether images are decoded after reading.
bool decode_;
// NOTE(review): appears unused by the implementation file - confirm before relying on it.
int64_t row_cnt_;
// Root directory of the dataset.
std::string folder_path_;
// Split to read, "train" or "test".
std::string usage_;
// Schema of the output columns.
std::unique_ptr<DataSchema> data_schema_;
// Zero-padded ids of the samples, e.g. "000000".
std::vector<std::string> image_ids_;
// Label name -> label index.
std::map<std::string, uint32_t> label_index_;
// Annotation file path -> objects parsed from that file.
std::map<std::string, Annotation> annotation_map_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_KITTI_OP_H_

View File

@ -102,6 +102,7 @@ constexpr char kImageFolderNode[] = "ImageFolderDataset";
constexpr char kIMDBNode[] = "IMDBDataset";
constexpr char kIWSLT2016Node[] = "IWSLT2016Dataset";
constexpr char kIWSLT2017Node[] = "IWSLT2017Dataset";
// Node name returned by KITTINode::Name().
constexpr char kKITTINode[] = "KITTIDataset";
constexpr char kKMnistNode[] = "KMnistDataset";
constexpr char kLFWNode[] = "LFWDataset";
constexpr char kLibriTTSNode[] = "LibriTTSDataset";

View File

@ -28,6 +28,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SOURCE_SRC_FILES
imdb_node.cc
iwslt2016_node.cc
iwslt2017_node.cc
kitti_node.cc
kmnist_node.cc
lfw_node.cc
libri_tts_node.cc

View File

@ -0,0 +1,146 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/source/kitti_node.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/datasetops/source/kitti_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Constructor for KITTINode.
// Stores the dataset directory, split, decode flag and sampler; the cache is handed to MappableSourceNode.
// The parameter list matches the in-class declaration exactly: a default argument that exists only on this
// out-of-line definition would be invisible to every other translation unit, so none is given here.
KITTINode::KITTINode(const std::string &dataset_dir, const std::string &usage, bool decode,
                     const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
    // `cache` is a const reference, so std::move on it would silently copy anyway; pass it as is.
    : MappableSourceNode(cache),
      dataset_dir_(dataset_dir),
      usage_(usage),
      decode_(decode),
      sampler_(sampler) {}
std::shared_ptr<DatasetNode> KITTINode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<KITTINode>(dataset_dir_, usage_, decode_, sampler, cache_);
return node;
}
void KITTINode::Print(std::ostream &out) const {
out << (Name() + "(cache: " + ((cache_ != nullptr) ? "true" : "false") + ")");
}
// Validates the node arguments: the base-class checks, the dataset directory, the sampler, and that usage is
// either "train" or "test".
// @return Status Status::OK() if all the parameters are valid.
Status KITTINode::ValidateParams() {
  RETURN_IF_NOT_OK(DatasetNode::ValidateParams());
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("KITTIDataset", dataset_dir_));
  RETURN_IF_NOT_OK(ValidateDatasetSampler("KITTIDataset", sampler_));
  RETURN_IF_NOT_OK(ValidateStringValue("KITTIDataset", usage_, {"train", "test"}));
  return Status::OK();
}
// Function to build KITTINode.
// Creates the output schema ("image" always; the eight annotation columns only for the "train" split), builds the
// runtime sampler and appends a configured KITTIOp to node_ops.
Status KITTINode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  RETURN_UNEXPECTED_IF_NULL(node_ops);
  auto schema = std::make_unique<DataSchema>();
  // Every column is a flexible rank-1 descriptor; only the name and element type differ.
  const auto add_column = [&schema](const std::string &col_name, const DataType &col_type) {
    return schema->AddColumn(ColDescriptor(col_name, col_type, TensorImpl::kFlexible, 1));
  };

  RETURN_IF_NOT_OK(add_column("image", DataType(DataType::DE_UINT8)));
  if (usage_ == "train") {
    RETURN_IF_NOT_OK(add_column("label", DataType(DataType::DE_UINT32)));
    RETURN_IF_NOT_OK(add_column("truncated", DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(add_column("occluded", DataType(DataType::DE_UINT32)));
    RETURN_IF_NOT_OK(add_column("alpha", DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(add_column("bbox", DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(add_column("dimensions", DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(add_column("location", DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(add_column("rotation_y", DataType(DataType::DE_FLOAT32)));
  }

  std::shared_ptr<SamplerRT> runtime_sampler = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&runtime_sampler));

  auto kitti_op = std::make_shared<KITTIOp>(dataset_dir_, usage_, num_workers_, connector_que_size_, decode_,
                                            std::move(schema), std::move(runtime_sampler));
  kitti_op->SetTotalRepeats(GetTotalRepeats());
  kitti_op->SetNumRepeatsPerEpoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(kitti_op);
  return Status::OK();
}
// Get the shard id of node, as reported by the sampler.
Status KITTINode::GetShardId(int32_t *shard_id) {
  RETURN_UNEXPECTED_IF_NULL(shard_id);
  const auto sampler_shard = sampler_->ShardId();
  *shard_id = sampler_shard;
  return Status::OK();
}
// Get Dataset size.
// Returns the cached size when one is known. Otherwise builds a KITTIOp to count the rows, asks a freshly built
// runtime sampler how many of them it would sample, and falls back to a dry run when the sampler answers -1.
// The result is cached in dataset_size_.
Status KITTINode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
                                 int64_t *dataset_size) {
  RETURN_UNEXPECTED_IF_NULL(dataset_size);
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }

  std::vector<std::shared_ptr<DatasetOp>> built_ops;
  RETURN_IF_NOT_OK(Build(&built_ops));
  CHECK_FAIL_RETURN_UNEXPECTED(!built_ops.empty(), "Unable to build KITTIOp.");
  auto kitti_op = std::dynamic_pointer_cast<KITTIOp>(built_ops.front());
  int64_t total_rows = 0;
  RETURN_IF_NOT_OK(kitti_op->CountTotalRows(&total_rows));

  std::shared_ptr<SamplerRT> runtime_sampler = nullptr;
  RETURN_IF_NOT_OK(sampler_->SamplerBuild(&runtime_sampler));
  int64_t sampled_rows = runtime_sampler->CalculateNumSamples(total_rows);
  if (sampled_rows == -1) {
    RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sampled_rows));
  }

  dataset_size_ = sampled_rows;
  *dataset_size = dataset_size_;
  return Status::OK();
}
// Serializes the node arguments (sampler, worker count, dataset directory, usage, decode flag and, when a cache
// is attached, the cache) into out_json.
Status KITTINode::to_json(nlohmann::json *out_json) {
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json node_json;
  nlohmann::json sampler_json;
  RETURN_IF_NOT_OK(sampler_->to_json(&sampler_json));
  node_json["sampler"] = sampler_json;
  node_json["num_parallel_workers"] = num_workers_;
  node_json["dataset_dir"] = dataset_dir_;
  node_json["usage"] = usage_;
  node_json["decode"] = decode_;
  if (cache_ != nullptr) {
    nlohmann::json cache_json;
    RETURN_IF_NOT_OK(cache_->to_json(&cache_json));
    node_json["cache"] = cache_json;
  }
  *out_json = node_json;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,104 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_KITTI_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_KITTI_NODE_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
/// \class KITTINode
/// \brief IR node describing a KITTI source dataset; Build() turns it into a runtime KITTIOp.
class KITTINode : public MappableSourceNode {
public:
/// \brief Constructor.
/// \param[in] dataset_dir Dataset directory of KITTI.
/// \param[in] usage Usage of this dataset, can be `train` or `test`.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Tells KITTIOp what to read.
/// \param[in] cache Tensor cache to use.
KITTINode(const std::string &dataset_dir, const std::string &usage, bool decode,
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor.
~KITTINode() = default;
/// \brief Node name getter.
/// \return Name of the current node.
std::string Name() const override { return kKITTINode; }
/// \brief Print the description.
/// \param out The output stream to write output to.
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object.
/// \return A shared pointer to the new copy.
std::shared_ptr<DatasetNode> Copy() override;
/// \brief A base class override function to create the required runtime dataset op objects for this class.
/// \param node_ops A vector containing shared pointer to the Dataset Ops that this object will create.
/// \return Status Status::OK() if build successfully.
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
/// \brief Parameters validation.
/// \return Status Status::OK() if all the parameters are valid.
Status ValidateParams() override;
/// \brief Get the shard id of node.
/// \param[out] shard_id The shard id reported by the sampler.
/// \return Status Status::OK() if get shard id successfully.
Status GetShardId(int32_t *shard_id) override;
/// \brief Base-class override for GetDatasetSize.
/// \param[in] size_getter Shared pointer to DatasetSizeGetter.
/// \param[in] estimate This is only supported by some of the ops and it's used to speed up the process of getting
/// dataset size at the expense of accuracy.
/// \param[out] dataset_size the size of the dataset.
/// \return Status of the function.
Status GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_getter, bool estimate,
int64_t *dataset_size) override;
/// \brief Getter functions.
const std::string &DatasetDir() const { return dataset_dir_; }
const std::string &Usage() const { return usage_; }
bool Decode() const { return decode_; }
/// \brief Get the arguments of node.
/// \param[out] out_json JSON string of all attributes.
/// \return Status of the function.
Status to_json(nlohmann::json *out_json) override;
/// \brief Sampler getter.
/// \return SamplerObj of the current node.
std::shared_ptr<SamplerObj> Sampler() override { return sampler_; }
/// \brief Sampler setter.
/// \param[in] sampler Tells KITTIOp what to read.
void SetSampler(std::shared_ptr<SamplerObj> sampler) override { sampler_ = sampler; }
private:
// Root directory of the dataset.
std::string dataset_dir_;
// Split to read, `train` or `test`.
std::string usage_;
// Whether images are decoded after reading.
bool decode_;
// Sampler that selects which rows are read.
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_KITTI_NODE_H_

View File

@ -3162,6 +3162,105 @@ inline std::shared_ptr<IWSLT2017Dataset> MS_API IWSLT2017(const std::string &dat
shard_id, cache);
}
/// \class KITTIDataset
/// \brief A source dataset that reads KITTI images and labels.
class MS_API KITTIDataset : public Dataset {
public:
/// \brief Constructor of KITTIDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data file to read.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const std::shared_ptr<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of KITTIDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data file to read.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const Sampler *sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Constructor of KITTIDataset.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data file to read.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Reference wrapper of a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use.
KITTIDataset(const std::vector<char> &dataset_dir, const std::vector<char> &usage, bool decode,
const std::reference_wrapper<Sampler> &sampler, const std::shared_ptr<DatasetCache> &cache);
/// \brief Destructor of KITTIDataset.
~KITTIDataset() = default;
};
/// \brief Function to create a KITTIDataset.
/// \note When usage is 'train', the generated dataset has multi-columns, 'image', 'label', 'truncated',
///     'occluded', 'alpha', 'bbox', 'dimensions', 'location', 'rotation_y'; When usage is 'test',
///     the generated dataset has one column 'image'.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data file to read (default = "train").
/// \param[in] decode Decode the images after reading (default = false).
/// \param[in] sampler Shared pointer to a sampler object used to choose samples from the dataset. If sampler is not
///     given, a `RandomSampler` will be used to randomly iterate the entire dataset (default = RandomSampler()).
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the current Dataset.
/// \par Example
/// \code
///      /* Define dataset path and MindData object */
///      std::string folder_path = "/path/to/kitti_dataset_directory";
///      std::shared_ptr<Dataset> ds = KITTI(folder_path);
///
///      /* Create iterator to read dataset */
///      std::shared_ptr<Iterator> iter = ds->CreateIterator();
///      std::unordered_map<std::string, mindspore::MSTensor> row;
///      iter->GetNextRow(&row);
///
///      /* Note: In KITTI dataset, each dictionary has key "image" */
///      auto image = row["image"];
/// \endcode
inline std::shared_ptr<KITTIDataset> MS_API
KITTI(const std::string &dataset_dir, const std::string &usage = "train", bool decode = false,
      const std::shared_ptr<Sampler> &sampler = std::make_shared<RandomSampler>(),
      const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Convert the std::string arguments to the char-vector form taken by the KITTIDataset constructor.
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<KITTIDataset>(dir_chars, usage_chars, decode, sampler, cache);
}
/// \brief Function to create a KITTIDataset.
/// \note When usage is 'train', the generated dataset has multi-columns, 'image', 'label', 'truncated',
///     'occluded', 'alpha', 'bbox', 'dimensions', 'location', 'rotation_y'; When usage is 'test',
///     the generated dataset has one column 'image'.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data file to read.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Raw pointer to a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the current Dataset.
inline std::shared_ptr<KITTIDataset> MS_API KITTI(const std::string &dataset_dir, const std::string &usage, bool decode,
                                                  const Sampler *sampler,
                                                  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Convert the std::string arguments to the char-vector form taken by the KITTIDataset constructor.
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<KITTIDataset>(dir_chars, usage_chars, decode, sampler, cache);
}
/// \brief Function to create a KITTIDataset.
/// \note When usage is 'train', the generated dataset has multi-columns, 'image', 'label', 'truncated',
///     'occluded', 'alpha', 'bbox', 'dimensions', 'location', 'rotation_y'; When usage is 'test',
///     the generated dataset has one column 'image'.
/// \param[in] dataset_dir Path to the root directory that contains the dataset.
/// \param[in] usage The type of data file to read.
/// \param[in] decode Decode the images after reading.
/// \param[in] sampler Reference wrapper of a sampler object used to choose samples from the dataset.
/// \param[in] cache Tensor cache to use (default=nullptr, which means no cache is used).
/// \return Shared pointer to the current Dataset.
inline std::shared_ptr<KITTIDataset> MS_API KITTI(const std::string &dataset_dir, const std::string &usage, bool decode,
                                                  const std::reference_wrapper<Sampler> &sampler,
                                                  const std::shared_ptr<DatasetCache> &cache = nullptr) {
  // Convert the std::string arguments to the char-vector form taken by the KITTIDataset constructor.
  const auto dir_chars = StringToChar(dataset_dir);
  const auto usage_chars = StringToChar(usage);
  return std::make_shared<KITTIDataset>(dir_chars, usage_chars, decode, sampler, cache);
}
/// \class KMnistDataset.
/// \brief A source dataset for reading and parsing KMnist dataset.
class MS_API KMnistDataset : public Dataset {

View File

@ -50,6 +50,7 @@ class MS_API Sampler : std::enable_shared_from_this<Sampler> {
friend class GTZANDataset;
friend class ImageFolderDataset;
friend class IMDBDataset;
friend class KITTIDataset;
friend class KMnistDataset;
friend class LFWDataset;
friend class LibriTTSDataset;

View File

@ -50,6 +50,7 @@ __all__ = ["Caltech101Dataset", # Vision
"FlickrDataset", # Vision
"Flowers102Dataset", # Vision
"ImageFolderDataset", # Vision
"KITTIDataset", # Vision
"KMnistDataset", # Vision
"LFWDataset", # Vision
"LSUNDataset", # Vision

View File

@ -32,7 +32,7 @@ import mindspore._c_dataengine as cde
from .datasets import VisionBaseDataset, SourceDataset, MappableDataset, Shuffle, Schema
from .datasets_user_defined import GeneratorDataset
from .validators import check_imagefolderdataset, \
from .validators import check_imagefolderdataset, check_kittidataset,\
check_mnist_cifar_dataset, check_manifestdataset, check_vocdataset, check_cocodataset, \
check_celebadataset, check_flickr_dataset, check_sb_dataset, check_flowers102dataset, check_cityscapes_dataset, \
check_usps_dataset, check_div2k_dataset, check_random_dataset, \
@ -2299,6 +2299,152 @@ class ImageFolderDataset(MappableDataset, VisionBaseDataset):
return cde.ImageFolderNode(self.dataset_dir, self.decode, self.sampler, self.extensions, self.class_indexing)
class KITTIDataset(MappableDataset):
    """
    A source dataset that reads and parses the KITTI dataset.

    When usage is "train", the generated dataset has multiple columns: :py:obj:`[image, label, truncated,
    occluded, alpha, bbox, dimensions, location, rotation_y]`; When usage is "test", the generated dataset
    has only one column: :py:obj:`[image]`.
    The tensor of column :py:obj:`image` is of the uint8 type.
    The tensor of column :py:obj:`label` is of the uint32 type.
    The tensor of column :py:obj:`truncated` is of the float32 type.
    The tensor of column :py:obj:`occluded` is of the uint32 type.
    The tensor of column :py:obj:`alpha` is of the float32 type.
    The tensor of column :py:obj:`bbox` is of the float32 type.
    The tensor of column :py:obj:`dimensions` is of the float32 type.
    The tensor of column :py:obj:`location` is of the float32 type.
    The tensor of column :py:obj:`rotation_y` is of the float32 type.

    Args:
        dataset_dir (str): Path to the root directory that contains the dataset.
        usage (str, optional): Usage of this dataset, can be `train` or `test`. `train` will read 7481
            train samples, `test` will read from 7518 test samples without label (default=None, will use `train`).
        num_samples (int, optional): The number of images to be included in the dataset
            (default=None, will include all images).
        num_parallel_workers (int, optional): Number of workers to read the data
            (default=None, number set in the config).
        shuffle (bool, optional): Whether to perform shuffle on the dataset (default=None, expected
            order behavior shown in the table).
        decode (bool, optional): Decode the images after reading (default=False).
        sampler (Sampler, optional): Object used to choose samples from the dataset
            (default=None, expected order behavior shown in the table).
        num_shards (int, optional): Number of shards that the dataset will be divided
            into (default=None). When this argument is specified, 'num_samples' reflects
            the max sample number of per shard.
        shard_id (int, optional): The shard ID within num_shards (default=None). This
            argument can only be specified when num_shards is also specified.
        cache (DatasetCache, optional): Use tensor caching service to speed up dataset processing
            (default=None, which means no cache is used).

    Raises:
        RuntimeError: If `sampler` and `shuffle` are specified at the same time.
        RuntimeError: If `sampler` and `num_shards`/`shard_id` are specified at the same time.
        RuntimeError: If `num_shards` is specified but `shard_id` is None.
        RuntimeError: If `shard_id` is specified but `num_shards` is None.
        ValueError: If `dataset_dir` does not exist.
        ValueError: If `shard_id` is invalid (< 0 or >= num_shards).

    Note:
        - This dataset can take in a `sampler`. `sampler` and `shuffle` are mutually exclusive.
          The table below shows what input arguments are allowed and their expected behavior.

    .. list-table:: Expected Order Behavior of Using `sampler` and `shuffle`
       :widths: 25 25 50
       :header-rows: 1

       * - Parameter `sampler`
         - Parameter `shuffle`
         - Expected Order Behavior
       * - None
         - None
         - random order
       * - None
         - True
         - random order
       * - None
         - False
         - sequential order
       * - Sampler object
         - None
         - order defined by sampler
       * - Sampler object
         - True
         - not allowed
       * - Sampler object
         - False
         - not allowed

    Examples:
        >>> kitti_dataset_dir = "/path/to/kitti_dataset_directory"
        >>>
        >>> # 1) Read all KITTI train dataset samples in kitti_dataset_dir (random order by default)
        >>> dataset = ds.KITTIDataset(dataset_dir=kitti_dataset_dir, usage="train")
        >>>
        >>> # 2) Read then decode all KITTI test dataset samples in kitti_dataset_dir in sequence
        >>> dataset = ds.KITTIDataset(dataset_dir=kitti_dataset_dir, usage="test",
        ...                           decode=True, shuffle=False)

    About KITTI dataset:

    KITTI (Karlsruhe Institute of Technology and Toyota Technological Institute) is one of the most popular
    datasets for use in mobile robotics and autonomous driving. It consists of hours of traffic scenarios
    recorded with a variety of sensor modalities, including high-resolution RGB, grayscale stereo cameras,
    and a 3D laser scanner. Despite its popularity, the dataset itself does not contain ground truth for
    semantic segmentation. However, various researchers have manually annotated parts of the dataset to fit
    their necessities. Álvarez et al. generated ground truth for 323 images from the road detection challenge
    with three classes: road, vertical, and sky. Zhang et al. annotated 252 (140 for training and 112 for testing)
    acquisitions RGB and Velodyne scans from the tracking challenge for ten object categories: building, sky,
    road, vegetation, sidewalk, car, pedestrian, cyclist, sign/pole, and fence.

    You can unzip the original KITTI dataset files into this directory structure and read by MindSpore's API.

    .. code-block::

        .
        └── kitti_dataset_directory
            ├── data_object_image_2
            │    ├── training
            │    │    └── image_2
            │    │         ├── 000000.png
            │    │         ├── 000001.png
            │    │         └── ...
            │    └── testing
            │         └── image_2
            │              ├── 000000.png
            │              ├── 000001.png
            │              └── ...
            └── data_object_label_2
                 └── training
                      └── label_2
                           ├── 000000.txt
                           ├── 000001.txt
                           └── ...

    Citation:

    .. code-block::

        @INPROCEEDINGS{Geiger2012CVPR,
        author={Andreas Geiger and Philip Lenz and Raquel Urtasun},
        title={Are we ready for Autonomous Driving? The KITTI Vision Benchmark Suite},
        booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)},
        year={2012}
        }
    """

    @check_kittidataset
    def __init__(self, dataset_dir, usage=None, num_samples=None, num_parallel_workers=None, shuffle=None,
                 decode=False, sampler=None, num_shards=None, shard_id=None, cache=None):
        super().__init__(num_parallel_workers=num_parallel_workers, sampler=sampler, num_samples=num_samples,
                         shuffle=shuffle, num_shards=num_shards, shard_id=shard_id, cache=cache)

        self.dataset_dir = dataset_dir
        # None falls back to the documented defaults.
        self.usage = replace_none(usage, "train")
        self.decode = replace_none(decode, False)

    def parse(self, children=None):
        """Build the C++ IR node (cde.KITTINode) from the stored arguments; `children` is not used."""
        return cde.KITTINode(self.dataset_dir, self.usage, self.decode, self.sampler)
class KMnistDataset(MappableDataset, VisionBaseDataset):
"""
A source dataset that reads and parses the KMNIST dataset.

View File

@ -258,6 +258,35 @@ def check_iwslt2017_dataset(method):
return new_method
def check_kittidataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(KITTIDataset)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        # The dataset directory is checked first.
        check_dir(param_dict.get('dataset_dir'))

        # usage is optional; only an explicitly given value is checked.
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test"], "usage")

        # Type checks for the optional integer and boolean arguments.
        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle', 'decode'], param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)
        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
def check_lsun_dataset(method):
"""A wrapper that wraps a parameter checker around the original Dataset(LSUNDataset)."""

View File

@ -37,6 +37,7 @@ SET(DE_UT_SRCS
c_api_dataset_imdb_test.cc
c_api_dataset_iterator_test.cc
c_api_dataset_iwslt_test.cc
c_api_dataset_kitti_test.cc
c_api_dataset_kmnist_test.cc
c_api_dataset_lfw_test.cc
c_api_dataset_libri_tts.cc

View File

@ -0,0 +1,249 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/core/tensor.h"
using namespace mindspore::dataset;
using mindspore::dataset::DataType;
using mindspore::dataset::Tensor;
using mindspore::dataset::TensorShape;
// Test fixture shared by the KITTI C++ API pipeline tests below; it adds nothing to UT::DatasetOpTesting.
class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
/// Feature: KITTIDatasetPipeline.
/// Description: test Pipeline of KITTI.
/// Expectation: get correct data.
TEST_F(MindDataTestPipeline, TestKITTIPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIPipeline.";

  // Create a KITTI Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "train", false, std::make_shared<SequentialSampler>(0, 2));
  // Fatal assertion: ds is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);

  // Iterate the dataset and get each row.
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  // Check if KITTI() read correct images.
  const std::vector<std::string> expect_file = {"000000", "000001", "000002"};
  uint64_t i = 0;
  while (row.size() != 0) {
    // Never index past the list of expected files, even if the pipeline yields extra rows.
    ASSERT_LT(i, expect_file.size());
    auto image = row["image"];
    auto label = row["label"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    MS_LOG(INFO) << "Tensor label shape: " << label.Shape();

    mindspore::MSTensor expect_image =
      ReadFileToTensor(folder_path + "/data_object_image_2/training/image_2/" + expect_file[i] + ".png");
    EXPECT_MSTENSOR_EQ(image, expect_image);

    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }

  // The sampler requests two samples.
  EXPECT_EQ(i, 2);

  // Manually terminate the pipeline.
  iter->Stop();
}
/// Feature: KITTITrainDatasetGetters.
/// Description: test usage of getters KITTITrainDataset.
/// Expectation: get correct number of data and correct tensor shape.
TEST_F(MindDataTestPipeline, TestKITTITrainDatasetGetters) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTITrainDatasetGetters.";

  // Create a KITTI Train Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "train", true, std::make_shared<SequentialSampler>(0, 2));
  // Fatal assertion: ds is dereferenced by every getter call below.
  ASSERT_NE(ds, nullptr);

  EXPECT_EQ(ds->GetDatasetSize(), 2);
  std::vector<DataType> types = ToDETypes(ds->GetOutputTypes());
  std::vector<std::string> column_names = {"image",      "label",    "truncated", "occluded", "alpha",
                                           "bbox",       "dimensions", "location",  "rotation_y"};
  int64_t num_classes = ds->GetNumClasses();

  // Expected output type of each column, in column order.
  const std::vector<std::string> expect_types = {"uint8",   "uint32",  "float32", "uint32", "float32",
                                                 "float32", "float32", "float32", "float32"};
  // Fatal assertion: a size mismatch would make the indexed checks below read out of range.
  ASSERT_EQ(types.size(), expect_types.size());
  for (size_t k = 0; k < expect_types.size(); ++k) {
    EXPECT_EQ(types[k].ToString(), expect_types[k]);
  }
  EXPECT_EQ(num_classes, -1);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);

  // The getters are deliberately queried again; repeated calls must keep returning the same values.
  EXPECT_EQ(ds->GetDatasetSize(), 2);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ds->GetNumClasses(), -1);

  EXPECT_EQ(ds->GetColumnNames(), column_names);
  EXPECT_EQ(ds->GetDatasetSize(), 2);
  EXPECT_EQ(ToDETypes(ds->GetOutputTypes()), types);
  EXPECT_EQ(ds->GetBatchSize(), 1);
  EXPECT_EQ(ds->GetRepeatCount(), 1);
  EXPECT_EQ(ds->GetNumClasses(), -1);
  EXPECT_EQ(ds->GetDatasetSize(), 2);
}
/// Feature: KITTIUsageTrainDecodeFalse.
/// Description: test get train dataset of KITTI and test decode.
/// Expectation: getters of KITTI get the correct value.
TEST_F(MindDataTestPipeline, TestKITTIUsageTrainDecodeFalse) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIUsageTrainDecodeFalse.";

  // Create a KITTI Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "train", false, std::make_shared<SequentialSampler>(0, 2));
  // Fatal assertions: each pointer is dereferenced by the next statement.
  ASSERT_NE(ds, nullptr);

  ds = ds->Batch(2);
  ASSERT_NE(ds, nullptr);
  ds = ds->Repeat(2);
  ASSERT_NE(ds, nullptr);

  // 2 samples batched by 2 give 1 row per epoch; repeated twice gives 2.
  EXPECT_EQ(ds->GetDatasetSize(), 2);
  std::vector<std::string> column_names = {"image",      "label",    "truncated", "occluded", "alpha",
                                           "bbox",       "dimensions", "location",  "rotation_y"};
  EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: TestKITTIUsageTestDecodeTrue.
/// Description: test get test dataset of KITTI and test the decode.
/// Expectation: getters of KITTI get the correct value.
TEST_F(MindDataTestPipeline, TestKITTIUsageTestDecodeTrue) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIUsageTestDecodeTrue.";

  // Create a KITTI Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "test", true, std::make_shared<SequentialSampler>(0, 2));
  // Fatal assertions: each pointer is dereferenced by the next statement.
  ASSERT_NE(ds, nullptr);

  ds = ds->Batch(2);
  ASSERT_NE(ds, nullptr);
  ds = ds->Repeat(2);
  ASSERT_NE(ds, nullptr);

  // 2 samples batched by 2 give 1 row per epoch; repeated twice gives 2.
  EXPECT_EQ(ds->GetDatasetSize(), 2);
  // With usage "test" only the image column is expected.
  std::vector<std::string> column_names = {"image"};
  EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: KITTIPipelineRandomSampler.
/// Description: test RandomSampler of KITTI.
/// Expectation: getters of KITTI get the correct value.
TEST_F(MindDataTestPipeline, TestKITTIPipelineRandomSampler) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIPipelineRandomSampler.";

  // Create a KITTI Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "test", true, std::make_shared<RandomSampler>(false, 2));
  // Fatal assertions: each pointer is dereferenced by the next statement.
  ASSERT_NE(ds, nullptr);

  ds = ds->Batch(2);
  ASSERT_NE(ds, nullptr);
  ds = ds->Repeat(2);
  ASSERT_NE(ds, nullptr);

  // 2 samples batched by 2 give 1 row per epoch; repeated twice gives 2.
  EXPECT_EQ(ds->GetDatasetSize(), 2);
  std::vector<std::string> column_names = {"image"};
  EXPECT_EQ(ds->GetColumnNames(), column_names);
}
/// Feature: KITTIPipelineDistributedSampler.
/// Description: test DistributedSampler of KITTI.
/// Expectation: getters of KITTI get the correct value.
TEST_F(MindDataTestPipeline, TestKITTIPipelineDistributedSampler) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIPipelineDistributedSampler.";

  // Create a KITTI Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  // num_shards=3, shard_id=0, shuffle=false, num_samples=0, seed=0, offset=-1, even_dist=true
  DistributedSampler sampler = DistributedSampler(3, 0, false, 0, 0, -1, true);
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "train", false, sampler);
  // Fatal assertion: ds is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  // Iterate the dataset and get each row
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_NE(iter, nullptr);
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  uint64_t i = 0;
  while (row.size() != 0) {
    i++;
    auto image = row["image"];
    MS_LOG(INFO) << "Tensor image shape: " << image.Shape();
    ASSERT_OK(iter->GetNextRow(&row));
  }

  // Shard 0 of 3 is expected to yield exactly one row.
  EXPECT_EQ(i, 1);
  iter->Stop();
}
/// Feature: KITTIDatasetWithNullSampler.
/// Description: test null sampler of KITTI.
/// Expectation: throw exception correctly.
TEST_F(MindDataTestPipeline, TestKITTIWithNullSamplerError) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIWithNullSamplerError.";

  // Create a KITTI Dataset.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "train", false, nullptr);
  // Fatal assertion: ds is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid KITTI input, sampler cannot be nullptr.
  EXPECT_EQ(iter, nullptr);
}
/// Feature: KITTIDatasetWithNullPath.
/// Description: test null path of KITTI.
/// Expectation: throw exception correctly.
TEST_F(MindDataTestPipeline, TestKITTIWithNullPath) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIWithNullPath.";

  // Create a KITTI Dataset with a valid sampler so that the empty path is the only invalid input.
  std::string folder_path = "";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "train", false, std::make_shared<SequentialSampler>(0, 2));
  // Fatal assertion: ds is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid KITTI input, path cannot be "".
  EXPECT_EQ(iter, nullptr);
}
/// Feature: KITTIDatasetWithWrongUsage.
/// Description: test wrong usage of KITTI.
/// Expectation: throw exception correctly.
TEST_F(MindDataTestPipeline, TestKITTIWithWrongUsage) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestKITTIWithWrongUsage.";

  // Create a KITTI Dataset with a valid path and sampler so that the usage is the only invalid input.
  std::string folder_path = datasets_root_path_ + "/testKITTI";
  std::shared_ptr<Dataset> ds = KITTI(folder_path, "all", false, std::make_shared<SequentialSampler>(0, 2));
  // Fatal assertion: ds is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure: invalid KITTI input, usage "all" is not one of the documented values "train" / "test".
  EXPECT_EQ(iter, nullptr);
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 172 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 172 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 53 KiB

View File

@ -0,0 +1,6 @@
Pedestrian 0.00 0 -0.20 712.40 143.00 810.73 307.92 1.89 0.48 1.20 1.84 1.47 8.41 0.01
Car 0.00 0 -1.84 662.20 185.85 690.21 205.03 1.48 1.36 3.51 5.35 2.56 58.84 -1.75
Van 0.00 0 1.70 448.07 177.14 481.60 206.41 2.50 2.20 5.78 -13.02 2.91 65.02 1.50
DontCare -1 -1 -10 610.50 179.95 629.68 196.31 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 582.97 182.70 594.78 191.05 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 600.36 185.59 608.36 192.69 -1 -1 -1 -1000 -1000 -1000 -10

View File

@ -0,0 +1,3 @@
Car 0.00 0 1.55 614.24 181.78 727.31 284.77 1.57 1.73 4.15 1.00 1.75 13.22 1.62
DontCare -1 -1 -10 5.00 229.89 214.12 367.61 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 522.25 202.35 547.77 219.71 -1 -1 -1 -1000 -1000 -1000 -10

View File

@ -0,0 +1,7 @@
Car 0.00 0 1.96 280.38 185.10 344.90 215.59 1.49 1.76 4.01 -15.71 2.16 38.26 1.57
Car 0.00 0 1.88 365.14 184.54 406.11 205.20 1.38 1.80 3.41 -15.89 2.23 51.17 1.58
DontCare -1 -1 -10 402.27 166.69 477.31 197.98 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 518.53 177.31 531.51 187.17 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 1207.50 233.35 1240.00 333.39 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 535.06 177.65 545.26 185.82 -1 -1 -1 -1000 -1000 -1000 -10
DontCare -1 -1 -10 558.03 177.88 567.50 184.65 -1 -1 -1 -1000 -1000 -1000 -10

View File

@ -0,0 +1,348 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
import pytest
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as vision
from mindspore import log as logger
# Root directory of the KITTI test fixture, relative to the directory the tests are run from.
DATA_DIR = "../data/dataset/testKITTI"
# Expected shape[0] of each decoded "test" image, in sequential read order (see test_kitti_usage_test).
IMAGE_SHAPE = [2268, 642, 2268]
def test_func_kitti_dataset_basic():
    """
    Feature: KITTI
    Description: test basic function of KITTI with default parameters
    Expectation: the dataset is as expected
    """
    repeat_count = 2
    # Expected first-dimension size of the image column and of every annotation column per row;
    # three samples, listed twice because the dataset is repeated.
    expected_image_sizes = [159109, 176455, 54214, 159109, 176455, 54214]
    expected_annotation_rows = [6, 3, 7, 6, 3, 7]
    annotation_columns = ("truncated", "occluded", "alpha", "bbox", "dimensions", "location", "rotation_y")

    # apply dataset operations.
    data = ds.KITTIDataset(DATA_DIR, shuffle=False).repeat(repeat_count)

    label_histogram = [0] * 8
    num_iter = 0
    # each row is a dictionary keyed by column name.
    for item in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        assert item["image"].shape[0] == expected_image_sizes[num_iter]
        for label in item["label"]:
            label_histogram[label[0]] += 1
        for column in annotation_columns:
            assert item[column].shape[0] == expected_annotation_rows[num_iter]
        num_iter += 1

    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 6
    assert label_histogram == [8, 20, 2, 2, 0, 0, 0, 0]
def test_kitti_usage_train():
    """
    Feature: KITTI
    Description: test basic usage "train" of KITTI
    Expectation: the dataset is as expected
    """
    train_data = ds.KITTIDataset(DATA_DIR, usage="train")

    label_histogram = [0] * 8
    num = 0
    for item in train_data.create_dict_iterator(num_epochs=1, output_numpy=True):
        # Tally every annotated object by its label index.
        for label in item["label"]:
            label_histogram[label[0]] += 1
        num += 1

    assert num == 3
    assert label_histogram == [4, 10, 1, 1, 0, 0, 0, 0]
def test_kitti_usage_test():
    """
    Feature: KITTI
    Description: test basic usage "test" of KITTI
    Expectation: the dataset is as expected
    """
    test_data = ds.KITTIDataset(DATA_DIR, usage="test", shuffle=False, decode=True, num_samples=3)

    num = 0
    rows = test_data.create_dict_iterator(num_epochs=1, output_numpy=True)
    # `num` counts the rows seen so far (1-based), so the expected entry is at num - 1.
    for num, item in enumerate(rows, start=1):
        assert item["image"].shape[0] == IMAGE_SHAPE[num - 1]
    assert num == 3
def test_kitti_case():
    """
    Feature: KITTI
    Description: test basic usage of KITTI
    Expectation: the dataset is as expected
    """
    repeat_num = 4
    batch_size = 2

    pipeline = ds.KITTIDataset(DATA_DIR, usage="train", decode=True, num_samples=3)
    pipeline = pipeline.map(operations=vision.Resize((224, 224)), input_columns=["image"])
    pipeline = pipeline.repeat(repeat_num)
    pipeline = pipeline.batch(batch_size, drop_remainder=True, pad_info={})

    # 3 samples repeated 4 times and batched by 2 give 6 batches.
    num = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1))
    assert num == 6
def test_func_kitti_dataset_numsamples_num_parallel_workers():
    """
    Feature: KITTI
    Description: test numsamples and num_parallel_workers of KITTI
    Expectation: the dataset is as expected
    """
    # num_samples=2 repeated twice gives 4 rows.
    repeat_count = 2
    limited = ds.KITTIDataset(DATA_DIR, num_samples=2, num_parallel_workers=2).repeat(repeat_count)
    num_iter = sum(1 for _ in limited.create_dict_iterator(num_epochs=1, output_numpy=True))
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 4

    # RandomSampler with replacement, 3 samples requested.
    with_replacement = ds.RandomSampler(num_samples=3, replacement=True)
    sampled = ds.KITTIDataset(DATA_DIR, num_parallel_workers=2, sampler=with_replacement)
    num_iter = sum(1 for _ in sampled.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert num_iter == 3

    # RandomSampler without replacement, 3 samples requested.
    without_replacement = ds.RandomSampler(num_samples=3, replacement=False)
    sampled = ds.KITTIDataset(DATA_DIR, num_parallel_workers=2, sampler=without_replacement)
    num_iter = sum(1 for _ in sampled.create_dict_iterator(num_epochs=1, output_numpy=True))
    assert num_iter == 3
def test_func_kitti_dataset_extrashuffle():
    """
    Feature: KITTI
    Description: test extrashuffle of KITTI
    Expectation: the dataset is as expected
    """
    repeat_count = 2

    # Shuffle at the source and once more with a shuffle op.
    pipeline = ds.KITTIDataset(DATA_DIR, shuffle=True)
    pipeline = pipeline.shuffle(buffer_size=3)
    pipeline = pipeline.repeat(repeat_count)

    num_iter = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 6
def test_func_kitti_dataset_no_para():
    """
    Feature: KITTI
    Description: test no para of KITTI
    Expectation: throw exception correctly
    """
    # The constructor call itself must raise; nothing placed after it inside the block would run.
    with pytest.raises(TypeError, match="missing a required argument: 'dataset_dir'"):
        ds.KITTIDataset()
def test_func_kitti_dataset_distributed_sampler():
    """
    Feature: KITTI
    Description: test DistributedSampler of KITTI
    Expectation: the dataset is as expected
    """
    repeat_count = 2

    # DistributedSampler(3, 1): num_shards=3, shard_id=1.
    shard_sampler = ds.DistributedSampler(3, 1)
    pipeline = ds.KITTIDataset(DATA_DIR, sampler=shard_sampler).repeat(repeat_count)

    num_iter = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 2
def test_func_kitti_dataset_decode():
    """
    Feature: KITTI
    Description: test decode of KITTI
    Expectation: the dataset is as expected
    """
    repeat_count = 2

    pipeline = ds.KITTIDataset(DATA_DIR, decode=True).repeat(repeat_count)

    # Only the number of rows is checked: 3 samples repeated twice.
    num_iter = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 6
def test_kitti_numshards():
    """
    Feature: KITTI
    Description: test numShards of KITTI
    Expectation: the dataset is as expected
    """
    repeat_count = 2

    # Read the last of three shards, then repeat it.
    pipeline = ds.KITTIDataset(DATA_DIR, num_shards=3, shard_id=2).repeat(repeat_count)

    num_iter = sum(1 for _ in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True))
    logger.info("Number of data in data1: {}".format(num_iter))
    assert num_iter == 2
def test_func_kitti_dataset_more_para():
    """
    Feature: KITTI
    Description: test KITTIDataset constructed with an unsupported extra keyword argument
    Expectation: TypeError is raised reporting the unexpected keyword argument 'more_para'
    """
    # Only the constructor call belongs inside the `pytest.raises` block: it is the
    # statement expected to raise, so anything placed after it would never execute.
    with pytest.raises(TypeError, match="got an unexpected keyword argument 'more_para'"):
        ds.KITTIDataset(DATA_DIR, usage="train", num_samples=6, num_parallel_workers=None,
                        shuffle=True, sampler=None, decode=True, num_shards=3,
                        shard_id=2, cache=None, more_para=None)
def test_kitti_exception():
    """
    Feature: KITTI
    Description: test invalid argument combinations and invalid argument values of KITTIDataset
    Expectation: each case raises the expected RuntimeError, ValueError or TypeError
    """
    logger.info("Test error cases for KITTIDataset")

    # A sampler cannot be combined with shuffle or with sharding.
    with pytest.raises(RuntimeError, match="sampler and shuffle cannot be specified at the same time"):
        ds.KITTIDataset(DATA_DIR, shuffle=False, decode=True, sampler=ds.SequentialSampler(1))

    with pytest.raises(RuntimeError, match="sampler and sharding cannot be specified at the same time"):
        ds.KITTIDataset(DATA_DIR, sampler=ds.SequentialSampler(1), decode=True, num_shards=2, shard_id=0)

    # num_shards and shard_id have to be given together.
    with pytest.raises(RuntimeError, match="num_shards is specified and currently requires shard_id as well"):
        ds.KITTIDataset(DATA_DIR, decode=True, num_shards=10)

    with pytest.raises(RuntimeError, match="shard_id is specified but num_shards is not"):
        ds.KITTIDataset(DATA_DIR, decode=True, shard_id=0)

    # With num_shards=5, shard_id values -1 and 5 are both rejected.
    for rejected_shard_id in (-1, 5):
        with pytest.raises(ValueError, match="Input shard_id is not within the required interval"):
            ds.KITTIDataset(DATA_DIR, decode=True, num_shards=5, shard_id=rejected_shard_id)

    # num_parallel_workers values 0 and 256 are both rejected.
    for rejected_workers in (0, 256):
        with pytest.raises(ValueError, match="num_parallel_workers exceeds"):
            ds.KITTIDataset(DATA_DIR, decode=True, shuffle=False, num_parallel_workers=rejected_workers)

    # shard_id passed as a string instead of an int.
    with pytest.raises(TypeError, match="Argument shard_id"):
        ds.KITTIDataset(DATA_DIR, decode=True, num_shards=2, shard_id="0")

    # For the remaining cases the pipeline is also iterated inside the `raises`
    # block, so the error is accepted whether it surfaces at construction or
    # while reading.
    with pytest.raises(ValueError, match="does not exist or is not a directory or permission denied!"):
        missing_dir_data = ds.KITTIDataset("../data/dataset/testKITTI2", decode=True)
        for _ in missing_dir_data.create_dict_iterator(num_epochs=1):
            pass

    # These messages contain regex metacharacters, hence re.escape.
    bad_usage_msg = "Input usage is not within the valid set of ['train', 'test']."
    with pytest.raises(ValueError, match=re.escape(bad_usage_msg)):
        bad_usage_data = ds.KITTIDataset(DATA_DIR, usage="all")
        for _ in bad_usage_data.create_dict_iterator(num_epochs=1):
            pass

    bad_decode_msg = "Argument decode with value 123 is not of type [<class 'bool'>], but got <class 'int'>."
    with pytest.raises(TypeError, match=re.escape(bad_decode_msg)):
        bad_decode_data = ds.KITTIDataset(DATA_DIR, decode=123)
        for _ in bad_decode_data.create_dict_iterator(num_epochs=1):
            pass
if __name__ == '__main__':
    # Run the KITTI test cases listed below, in order, when executed as a script.
    for kitti_test_case in (
            test_func_kitti_dataset_basic,
            test_kitti_usage_train,
            test_kitti_usage_test,
            test_kitti_case,
            test_func_kitti_dataset_numsamples_num_parallel_workers,
            test_func_kitti_dataset_extrashuffle,
            test_func_kitti_dataset_no_para,
            test_func_kitti_dataset_distributed_sampler,
            test_func_kitti_dataset_decode,
            test_kitti_numshards,
            test_func_kitti_dataset_more_para,
            test_kitti_exception,
    ):
        kitti_test_case()