!13378 【MD】【Task】add data help for MD lite

From: @xulei2020
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2021-03-17 09:13:43 +08:00 committed by Gitee
commit 888b2e19ee
10 changed files with 803 additions and 50 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONSTANTS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONSTANTS_H_
#include <cstdint>
#include <limits>
@ -27,7 +27,7 @@ using uchar = unsigned char;
using dsize_t = int64_t;
// Target devices to perform map operation
enum class MapTargetDevice { kCpu, kGpu, kDvpp };
enum class MapTargetDevice { kCpu, kGpu, kAscend310 };
// Possible dataset types for holding the data and client type
enum class DatasetType { kUnknown, kArrow, kTf };
@ -71,6 +71,9 @@ enum class NormalizeForm {
kNfkd,
};
// Possible values for SamplingStrategy
enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 };
// convenience functions for 32bit int bitmask
// Returns true only when every bit that is set in bitMask is also set in bits
// (an all-zero bitMask therefore always tests true).
inline bool BitTest(uint32_t bits, uint32_t bitMask) {
  const uint32_t missing_bits = bitMask & ~bits;
  return missing_bits == 0;
}
@ -84,7 +87,7 @@ constexpr int64_t kDeMaxFreq = std::numeric_limits<int64_t>::max(); // 92233720
constexpr int64_t kDeMaxTopk = std::numeric_limits<int64_t>::max();
constexpr uint32_t kCfgRowsPerBuffer = 1;
constexpr uint32_t kCfgParallelWorkers = 4;
constexpr uint32_t kCfgParallelWorkers = 8;
constexpr uint32_t kCfgWorkerConnectorSize = 16;
constexpr uint32_t kCfgOpConnectorSize = 16;
constexpr int32_t kCfgDefaultRankId = -1;
@ -106,4 +109,4 @@ using row_id_type = int64_t;
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONSTANTS_H_

View File

@ -0,0 +1,448 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATA_HELPER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATA_HELPER_H_
#include <sys/stat.h>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
namespace mindspore {
namespace dataset {
/// \brief Simple class to do data manipulation, contains helper functions to update json files in dataset
/// \note Every public std::string based method is an inline wrapper that converts its string arguments to
///     std::vector<char> and forwards to a private *IF method (dual ABI support).
class DataHelper {
 public:
  /// \brief Constructor
  DataHelper() {}

  /// \brief Destructor
  ~DataHelper() = default;

  /// \brief Create an Album dataset while taking in a path to a image folder
  ///     Creates the output directory if doesn't exist
  /// \param[in] in_dir Image folder directory that takes in images
  /// \param[in] out_dir Directory containing output json files
  /// \return Status The status code returned
  Status CreateAlbum(const std::string &in_dir, const std::string &out_dir) {
    return CreateAlbumIF(StringToChar(in_dir), StringToChar(out_dir));
  }

  /// \brief Update a json file field with a vector of string values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), VectorStringToChar(value), StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of bool values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<bool> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of int8 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int8_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of uint8 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint8_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of int16 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int16_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of uint16 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint16_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of int32 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int32_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of uint32 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint32_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of int64 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int64_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of uint64 values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint64_t> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of float values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<float> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a vector of double values
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value array to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<double> &value,
                     const std::string &out_file = "") {
    return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a string value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const std::string &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), StringToChar(value), StringToChar(out_file));
  }

  /// \brief Update a json file field with a bool value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const bool &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an int8 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const int8_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an uint8 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const uint8_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an int16 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const int16_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an uint16 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const uint16_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an int32 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const int32_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an uint32 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const uint32_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an int64 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const int64_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with an uint64 value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const uint64_t &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a float value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const float &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Update a json file field with a double value
  /// \param[in] in_file The input file name to read in
  /// \param[in] key Key of field to write to
  /// \param[in] value Value to write to file
  /// \param[in] out_file Optional parameter for output file path, will write to input file if not specified
  /// \return Status The status code returned
  Status UpdateValue(const std::string &in_file, const std::string &key, const double &value,
                     const std::string &out_file = "") {
    return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
  }

  /// \brief Template function to write tensor to file
  /// \note The elements are written as raw bytes (data.size() * sizeof(T)). An empty vector produces an
  ///     empty file.
  /// \param[in] in_file File to write to
  /// \param[in] data Array of type T values
  /// \return Status The status code returned; kMDUnexpectedError if the file cannot be opened or the
  ///     write/close fails
  template <typename T>
  Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) {
    try {
      std::ofstream o(in_file, std::ios::binary | std::ios::out);
      if (!o.is_open()) {
        return Status(kMDUnexpectedError, "Error opening Bin file to write");
      }
      // An empty vector has no first element to take the address of, so only write when there is data.
      if (!data.empty()) {
        o.write(reinterpret_cast<const char *>(data.data()), std::streamsize(data.size() * sizeof(T)));
      }
      o.close();
      // Output is buffered, so a failure may only surface when close() flushes; check the state afterwards.
      if (o.fail()) {
        return Status(kMDUnexpectedError, "Write bin file failed ");
      }
    }
    // Catch any exception and convert to Status return code
    catch (const std::exception &) {
      return Status(kMDUnexpectedError, "Write bin file failed ");
    }
    return Status::OK();
  }

  /// \brief Write pointer to bin, use pointer to avoid memcpy
  /// \note length * sizeof(T) raw bytes are written starting at data.
  /// \param[in] in_file File name to write to
  /// \param[in] data Pointer to data
  /// \param[in] length Length of values to write from pointer
  /// \return Status The status code returned; kMDUnexpectedError if data is null while length is non-zero,
  ///     if the file cannot be opened, or if the write/close fails
  template <typename T>
  Status WriteBinFile(const std::string &in_file, T *data, size_t length) {
    // Reject a null source before opening, so the target file is not truncated for an invalid request.
    if (data == nullptr && length > 0) {
      return Status(kMDUnexpectedError, "Write bin file failed: data pointer is null");
    }
    try {
      std::ofstream o(in_file, std::ios::binary | std::ios::out);
      if (!o.is_open()) {
        return Status(kMDUnexpectedError, "Error opening Bin file to write");
      }
      if (length > 0) {
        o.write(reinterpret_cast<const char *>(data), std::streamsize(length * sizeof(T)));
      }
      o.close();
      // Output is buffered, so a failure may only surface when close() flushes; check the state afterwards.
      if (o.fail()) {
        return Status(kMDUnexpectedError, "Write bin file failed ");
      }
    }
    // Catch any exception and convert to Status return code
    catch (const std::exception &) {
      return Status(kMDUnexpectedError, "Write bin file failed ");
    }
    return Status::OK();
  }

  /// \brief Helper function to copy content of a tensor to buffer
  /// \note The tensor memory is passed in as raw bytes (unsigned char)
  /// \param[in] tensor_addr The memory held by a tensor
  /// \param[in] tensor_size The amount of data in bytes in tensor_addr, e.g. tensor->SizeInBytes()
  /// \param[out] addr The address to copy tensor data to
  /// \param[in] buffer_size The buffer size of addr
  /// \return The size of the tensor (bytes copied)
  size_t DumpData(const unsigned char *tensor_addr, const size_t &tensor_size, void *addr, const size_t &buffer_size);

  /// \brief Helper function to delete key in json file
  /// \note This function will return okay even if key not found
  /// \param[in] in_file Json file to remove key from
  /// \param[in] key The key to remove
  /// \param[in] out_file Optional parameter for output file path (default=""); presumably the input file is
  ///     rewritten when it is empty, as for UpdateValue/UpdateArray - confirm against the implementation
  /// \return Status The status code returned
  Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = "") {
    return RemoveKeyIF(StringToChar(in_file), StringToChar(key), StringToChar(out_file));
  }

  /// \brief A print method typically used for debugging
  /// \param[out] out The output stream to write output to
  void Print(std::ostream &out) const;

  /// \brief << Stream output operator overload
  /// \note This allows you to write the debug print info using stream operators
  /// \param[out] out Reference to the output stream being overloaded
  /// \param[in] dh Reference to the DataHelper to display
  /// \return The output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const DataHelper &dh) {
    dh.Print(out);
    return out;
  }

 private:
  // Helper functions for dual ABI support: each takes the std::vector<char> form of the string arguments
  // that the public wrapper of the same base name (without the IF suffix) produces.

  // Backs CreateAlbum
  Status CreateAlbumIF(const std::vector<char> &in_dir, const std::vector<char> &out_dir);

  // Back the UpdateArray overloads, one per supported element type
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<std::vector<char>> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<bool> &value,
                       const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<int8_t> &value,
                       const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<uint8_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<int16_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<uint16_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<int32_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<uint32_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<int64_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
                       const std::vector<uint64_t> &value, const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<float> &value,
                       const std::vector<char> &out_file);
  Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<double> &value,
                       const std::vector<char> &out_file);

  // Back the UpdateValue overloads, one per supported value type
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<char> &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const bool &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int8_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint8_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int16_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint16_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int32_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint32_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int64_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint64_t &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const float &value,
                       const std::vector<char> &out_file);
  Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const double &value,
                       const std::vector<char> &out_file);

  // Backs RemoveKey
  Status RemoveKeyIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<char> &out_file);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATA_HELPER_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@
#include <sys/stat.h>
#include <unistd.h>
#include <algorithm>
#include <map>
#include <memory>
@ -30,6 +31,7 @@
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/types.h"
#include "include/iterator.h"
#include "include/samplers.h"
#include "include/transforms.h"
@ -39,11 +41,15 @@ namespace dataset {
class Tensor;
class TensorShape;
class TreeAdapter;
class TreeAdapterLite;
class TreeGetters;
class DatasetCache;
class DatasetNode;
class Iterator;
class TensorOperation;
class SchemaObj;
class SamplerObj;
@ -75,13 +81,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return dataset size. If failed, return -1
int64_t GetDatasetSize(bool estimate = false);
// /// \brief Gets the output type
// /// \return a vector of DataType. If failed, return an empty vector
// std::vector<DataType> GetOutputTypes();
/// \brief Gets the output type
/// \return a vector of DataType. If failed, return an empty vector
std::vector<mindspore::DataType> GetOutputTypes();
/// \brief Gets the output shape
/// \return a vector of shapes, each shape given as a vector of int64_t dimensions. If failed, return an empty vector
std::vector<TensorShape> GetOutputShapes();
std::vector<std::vector<int64_t>> GetOutputShapes();
/// \brief Gets the batch size
/// \return int64_t
@ -110,6 +116,11 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the original object
std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
/// \brief Function to create a PullBasedIterator over the Dataset
/// \param[in] columns List of columns to be used to specify the order of columns (default={}); each column
///     name is given as a std::vector<char>
/// \return Shared pointer to the PullIterator
std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {});
/// \brief Function to create an Iterator over the Dataset pipeline
/// \param[in] columns List of columns to be used to specify the order of columns
/// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
@ -119,6 +130,41 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
}
/// \brief Function to transfer data through a device.
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
///     of data transmission per time is 256M.
/// \param[in] queue_name Channel name (default="", create new unique name).
/// \param[in] device_type Type of device (default="", get from MSContext).
/// \param[in] device_id id of device (default=0, get from MSContext).
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
///     of data or not (default=false).
/// \return Returns true if no error encountered else false.
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
                 int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
                 bool create_data_info_queue = false) {
  // The string arguments are converted to std::vector<char> before being forwarded to the char interface.
  return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
                           total_batches, create_data_info_queue);
}
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
/// \note Usage restrictions:
///     1. Supported dataset formats: 'mindrecord' only
///     2. To save the samples in order, set dataset's shuffle to false and num_files to 1.
///     3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
///        with random attribute in map operator.
///     4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor
///        multi-dimensional string.
/// \param[in] dataset_path Path to dataset file
/// \param[in] num_files Number of dataset files (default=1)
/// \param[in] dataset_type Dataset format (default="mindrecord")
/// \return Returns true if no error encountered else false
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
  // The string arguments are converted to std::vector<char> before being forwarded to the char interface.
  return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
}
/// \brief Function to create a BatchDataset
/// \notes Combines batch_size number of consecutive rows into batches
/// \param[in] batch_size The number of rows each batch is created with
@ -131,8 +177,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
/// applied in the order they appear in this list
/// \param[in] operations Vector of raw pointers to TensorTransform objects to be applied on the dataset. Operations
/// are applied in the order they appear in this list
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
/// operation as input. The size of this list must match the number of
/// input columns expected by the first operator. The default input_columns
@ -160,6 +206,22 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
callbacks);
}
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of shared pointers to TensorTransform objects to be applied on the dataset.
/// Operations are applied in the order they appear in this list
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
/// operation as input. The size of this list must match the number of
/// input columns expected by the first operator. The default input_columns
/// is the first column
/// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
/// This parameter is mandatory if len(input_columns) != len(output_columns)
/// The size of this list must match the number of output columns of the
/// last operation. The default output_columns will have the same
/// name as the input columns, i.e., the columns will be replaced
/// \param[in] project_columns A list of column names to project
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current MapDataset
std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {},
@ -176,6 +238,22 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
callbacks);
}
/// \brief Function to create a MapDataset
/// \notes Applies each operation in operations to this dataset
/// \param[in] operations Vector of TensorTransform objects to be applied on the dataset. Operations are applied in
/// the order they appear in this list
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
/// operation as input. The size of this list must match the number of
/// input columns expected by the first operator. The default input_columns
/// is the first column
/// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
/// This parameter is mandatory if len(input_columns) != len(output_columns)
/// The size of this list must match the number of output columns of the
/// last operation. The default output_columns will have the same
/// name as the input columns, i.e., the columns will be replaced
/// \param[in] project_columns A list of column names to project
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
/// \return Shared pointer to the current MapDataset
std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
const std::vector<std::string> &input_columns = {},
const std::vector<std::string> &output_columns = {},
@ -221,6 +299,115 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
// Char interface(CharIF) of CreateIterator
std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
// Char interface(CharIF) of DeviceQueue
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
// Char interface(CharIF) of Save
bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
};
/// \class SchemaObj
/// \brief A dataset schema: columns can be added to it and it can be serialized to a JSON string.
/// \note The std::string-based public methods forward to private std::vector<char>-based counterparts.
class SchemaObj {
public:
/// \brief Constructor
/// \param[in] schema_file Path of schema file (default="").
explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}
/// \brief Destructor
~SchemaObj() = default;
/// \brief SchemaObj Init function
/// \return Status code
Status Init();
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name Name of the column.
/// \param[in] ms_type Data type of the column(mindspore::DataType).
/// \return Status code
Status add_column(const std::string &name, mindspore::DataType ms_type) {
return add_column_char(StringToChar(name), ms_type);
}
/// \brief Add new column to the schema with unknown shape of rank 1
/// \param[in] name Name of the column.
/// \param[in] ms_type Data type of the column(std::string).
/// \return Status code
Status add_column(const std::string &name, const std::string &ms_type) {
return add_column_char(StringToChar(name), StringToChar(ms_type));
}
/// \brief Add new column to the schema
/// \param[in] name Name of the column.
/// \param[in] ms_type Data type of the column(mindspore::DataType).
/// \param[in] shape Shape of the column.
/// \return Status code
Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
return add_column_char(StringToChar(name), ms_type, shape);
}
/// \brief Add new column to the schema
/// \param[in] name Name of the column.
/// \param[in] ms_type Data type of the column(std::string).
/// \param[in] shape Shape of the column.
/// \return Status code
Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
}
/// \brief Get a JSON string of the schema
/// \return JSON string of the schema
std::string to_json() { return CharToString(to_json_char()); }
/// \brief Get a JSON string of the schema (same result as to_json())
/// \return JSON string of the schema
std::string to_string() { return to_json(); }
/// \brief Set a new value to dataset_type
/// \param[in] dataset_type New value of dataset_type.
void set_dataset_type(std::string dataset_type);
/// \brief Set a new value to num_rows
/// \param[in] num_rows New value of num_rows.
void set_num_rows(int32_t num_rows);
/// \brief Get the current num_rows
/// \return The current num_rows
int32_t get_num_rows() const;
/// \brief Get schema file from JSON file
/// \param[in] json_string Name of JSON file to be parsed.
/// \note NOTE(review): the parameter name suggests this may be the JSON text itself rather than a file name;
/// confirm against the implementation.
/// \return Status code
Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }
/// \brief Parse and add column information
/// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file.
/// \return Status code
Status ParseColumnString(const std::string &json_string) {
return ParseColumnStringCharIF(StringToChar(json_string));
}
private:
// Char constructor of SchemaObj
explicit SchemaObj(const std::vector<char> &schema_file);
// Char interface of add_column
Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);
Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);
Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);
Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
const std::vector<int32_t> &shape);
// Char interface of to_json
const std::vector<char> to_json_char();
// Char interface of FromJSONString
Status FromJSONStringCharIF(const std::vector<char> &json_string);
// Char interface of ParseColumnString
Status ParseColumnStringCharIF(const std::vector<char> &json_string);
// Internal state; Data is only forward-declared here and is defined elsewhere.
struct Data;
std::shared_ptr<Data> data_;
};
class BatchDataset : public Dataset {
@ -252,12 +439,17 @@ class ShuffleDataset : public Dataset {
/// \brief Function to create a SchemaObj
/// \param[in] schema_file Path of schema file
/// \note This API exists because std::string is constrained by the ABI compile macro, whereas char is not.
/// \return Shared pointer to the current schema
std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);
/// \brief Create a SchemaObj from a schema file path.
/// \param[in] schema_file Path of schema file (default="").
/// \return Shared pointer to the current schema
inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
  // Convert to the char-based form first, then forward to the char interface.
  const auto schema_file_chars = StringToChar(schema_file);
  return SchemaCharIF(schema_file_chars);
}
class AlbumDataset : public Dataset {
public:
AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
@ -373,7 +565,6 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
const std::shared_ptr<DatasetCache> &cache = nullptr) {
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
}
} // namespace dataset
} // namespace mindspore

View File

@ -17,15 +17,17 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_EXECUTE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_EXECUTE_H_
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "include/api/context.h"
#include "include/api/types.h"
#include "include/constants.h"
#include "dataset/include/transforms.h"
#include "include/transforms.h"
namespace mindspore {
namespace dataset {
class DeviceResource;
// class to run tensor operations in eager mode
class Execute {
@ -34,7 +36,7 @@ class Execute {
// FIXME - Temporarily overload Execute to support both TensorOperation and TensorTransform
explicit Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
explicit Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
// explicit Execute(TensorTransform op, MapTargetDevice deviceType = MapTargetDevice::KCpu);
explicit Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
explicit Execute(TensorTransform *op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
explicit Execute(std::vector<std::shared_ptr<TensorOperation>> ops,
@ -62,14 +64,23 @@ class Execute {
Status DeviceMemoryRelease();
std::string AippCfgGenerator();
private:
Status ParseTransforms_();
Status validate_device_();
std::vector<std::shared_ptr<TensorTransform>> transforms_;
std::vector<std::shared_ptr<TensorOperation>> ops_;
MapTargetDevice device_type_;
std::shared_ptr<DeviceResource> device_resource_;
struct ExtraInfo;
std::shared_ptr<ExtraInfo> info_;
};
} // namespace dataset

View File

@ -37,6 +37,7 @@ class Tensor;
class NativeRuntimeContext;
class IteratorConsumer;
class PullBasedIteratorConsumer;
class Dataset;
@ -80,7 +81,7 @@ class Iterator {
/// \note Type of return data is a vector(without column name).
/// \param[out] row - the output tensor row.
/// \return - a Status error code, returns OK if no error encountered.
Status GetNextRow(MSTensorVec *row);
virtual Status GetNextRow(MSTensorVec *row);
/// \brief Function to shut down the data pipeline.
void Stop();
@ -131,6 +132,35 @@ class Iterator {
std::unique_ptr<NativeRuntimeContext> runtime_context_;
IteratorConsumer *consumer_;
};
/// \class PullIterator
/// \brief An Iterator subclass that overrides GetNextRow and adds GetRows and BuildAndLaunchTree.
class PullIterator : public Iterator {
public:
/// \brief Constructor
PullIterator();
/// \brief Function to get next row from the data pipeline.
/// \note Type of return data is a vector(without column name).
/// \param[out] row - the output tensor row.
/// \return - a Status error code, returns OK if no error encountered.
Status GetNextRow(MSTensorVec *row) override;
/// \brief Function to get specified rows from the data pipeline.
/// \note Type of return data is a vector(without column name).
/// \note This behavior is subject to change
/// \param[in] num_rows - the number of rows to fetch.
/// \param[out] row - the output tensor rows.
/// \return - a Status error code, returns OK if no error encountered.
Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *row);
/// \brief Method for building and launching the pipeline.
/// \note Consider making this function protected.
/// \param[in] ds - The root node that calls the function
/// \return - a Status error code, returns OK if no error encountered.
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
private:
// Consumer owned by this iterator; its use is defined in the implementation file.
std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_

View File

@ -18,14 +18,12 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
namespace mindspore {
namespace dataset {
// Forward declare
class SamplerObj;
// Abstract class to represent a sampler in the data pipeline.
@ -33,7 +31,20 @@ class SamplerObj;
/// \brief An abstract base class to represent a sampler in the data pipeline.
class Sampler : std::enable_shared_from_this<Sampler> {
friend class AlbumDataset;
friend class CelebADataset;
friend class Cifar10Dataset;
friend class Cifar100Dataset;
friend class CLUEDataset;
friend class CocoDataset;
friend class CSVDataset;
friend class ImageFolderDataset;
friend class ManifestDataset;
friend class MindDataDataset;
friend class MnistDataset;
friend class RandomDataDataset;
friend class TextFileDataset;
friend class TFRecordDataset;
friend class VOCDataset;
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -57,7 +68,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
/// \brief A class to represent a Distributed Sampler in the data pipeline.
/// \notes A Sampler that accesses a shard of the dataset.
class DistributedSampler : public Sampler {
class DistributedSampler final : public Sampler {
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -93,7 +104,7 @@ class DistributedSampler : public Sampler {
/// \brief A class to represent a PK Sampler in the data pipeline.
/// \notes Samples K elements for each P class in the dataset.
/// This will sample all classes.
class PKSampler : public Sampler {
class PKSampler final : public Sampler {
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -119,7 +130,7 @@ class PKSampler : public Sampler {
/// \brief A class to represent a Random Sampler in the data pipeline.
/// \notes Samples the elements randomly.
class RandomSampler : public Sampler {
class RandomSampler final : public Sampler {
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -143,7 +154,7 @@ class RandomSampler : public Sampler {
/// \brief A class to represent a Sequential Sampler in the data pipeline.
/// \notes Samples the dataset elements sequentially, same as not having a sampler.
class SequentialSampler : public Sampler {
class SequentialSampler final : public Sampler {
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -190,7 +201,7 @@ class SubsetSampler : public Sampler {
/// \brief A class to represent a Subset Random Sampler in the data pipeline.
/// \notes Samples the elements randomly from a sequence of indices.
class SubsetRandomSampler : public SubsetSampler {
class SubsetRandomSampler final : public SubsetSampler {
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -211,7 +222,7 @@ class SubsetRandomSampler : public SubsetSampler {
/// \brief A class to represent a Weighted Random Sampler in the data pipeline.
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
/// weights (probabilities).
class WeightedRandomSampler : public Sampler {
class WeightedRandomSampler final : public Sampler {
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
public:
@ -234,6 +245,7 @@ class WeightedRandomSampler : public Sampler {
int64_t num_samples_;
bool replacement_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,9 +17,11 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/constants.h"
@ -29,10 +31,32 @@ namespace dataset {
class TensorOperation;
// We need the following two groups of forward declaration to friend the class in class TensorTransform.
namespace transforms {
class Compose;
class RandomApply;
class RandomChoice;
} // namespace transforms
namespace vision {
class BoundingBoxAugment;
class RandomSelectSubpolicy;
class UniformAugment;
} // namespace vision
// Abstract class to represent a tensor transform operation in the data pipeline.
/// \class TensorTransform transforms.h
/// \brief A base class to represent a tensor transform operation in the data pipeline.
class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
friend class Dataset;
friend class Execute;
friend class transforms::Compose;
friend class transforms::RandomApply;
friend class transforms::RandomChoice;
friend class vision::BoundingBoxAugment;
friend class vision::RandomSelectSubpolicy;
friend class vision::UniformAugment;
public:
/// \brief Constructor
TensorTransform() {}
@ -40,6 +64,7 @@ class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
/// \brief Destructor
~TensorTransform() = default;
protected:
/// \brief Pure virtual function to convert a TensorTransform class into a IR TensorOperation object.
/// \return shared pointer to the newly created TensorOperation.
virtual std::shared_ptr<TensorOperation> Parse() = 0;
@ -55,17 +80,22 @@ namespace transforms {
/// \brief Compose Op.
/// \notes Compose a list of transforms into a single transform.
class Compose : public TensorTransform {
class Compose final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] transforms A vector of transformations to be applied.
/// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
explicit Compose(const std::vector<TensorTransform *> &transforms);
/// \brief Constructor.
/// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
explicit Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
/// \brief Constructor.
/// \param[in] transforms A vector of TensorTransform objects to be applied.
explicit Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
/// \brief Destructor
~Compose() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -78,7 +108,7 @@ class Compose : public TensorTransform {
/// \brief Duplicate Op.
/// \notes Duplicate the input tensor to a new output tensor.
/// The input tensor is carried over to the output list.
class Duplicate : public TensorTransform {
class Duplicate final : public TensorTransform {
public:
/// \brief Constructor.
Duplicate();
@ -86,6 +116,7 @@ class Duplicate : public TensorTransform {
/// \brief Destructor
~Duplicate() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -93,7 +124,7 @@ class Duplicate : public TensorTransform {
/// \brief OneHot Op.
/// \notes Convert the labels into OneHot format.
class OneHot : public TensorTransform {
class OneHot final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] num_classes number of classes.
@ -102,6 +133,7 @@ class OneHot : public TensorTransform {
/// \brief Destructor
~OneHot() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -113,18 +145,25 @@ class OneHot : public TensorTransform {
/// \brief RandomApply Op.
/// \notes Randomly perform a series of transforms with a given probability.
class RandomApply : public TensorTransform {
class RandomApply final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] transforms A vector of transformations to be applied.
/// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
/// \param[in] prob The probability to apply the transformation list (default=0.5)
explicit RandomApply(const std::vector<TensorTransform *> &transforms, double prob = 0.5);
/// \brief Constructor.
/// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
/// \param[in] prob The probability to apply the transformation list (default=0.5)
explicit RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob = 0.5);
/// \brief Constructor.
/// \param[in] transforms A vector of TensorTransform objects to be applied.
/// \param[in] prob The probability to apply the transformation list (default=0.5)
explicit RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob = 0.5);
/// \brief Destructor
~RandomApply() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -136,17 +175,22 @@ class RandomApply : public TensorTransform {
/// \brief RandomChoice Op.
/// \notes Randomly selects one transform from a list of transforms to perform operation.
class RandomChoice : public TensorTransform {
class RandomChoice final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] transforms A vector of transformations to be chosen from to apply.
/// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
explicit RandomChoice(const std::vector<TensorTransform *> &transforms);
/// \brief Constructor.
/// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
explicit RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
/// \brief Constructor.
/// \param[in] transforms A vector of TensorTransform objects to be applied.
explicit RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
/// \brief Destructor
~RandomChoice() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -158,7 +202,7 @@ class RandomChoice : public TensorTransform {
/// \brief TypeCast Op.
/// \notes Tensor operation to cast to a given MindSpore data type.
class TypeCast : public TensorTransform {
class TypeCast final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] data_type mindspore.dtype to be cast to.
@ -169,6 +213,7 @@ class TypeCast : public TensorTransform {
/// \brief Destructor
~TypeCast() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -181,7 +226,7 @@ class TypeCast : public TensorTransform {
/// \brief Unique Op.
/// \notes Return an output tensor containing all the unique elements of the input tensor in
/// the same order that they occur in the input tensor.
class Unique : public TensorTransform {
class Unique final : public TensorTransform {
public:
/// \brief Constructor.
Unique();
@ -189,6 +234,7 @@ class Unique : public TensorTransform {
/// \brief Destructor
~Unique() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -22,6 +22,7 @@
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "include/constants.h"
#include "include/transforms.h"
@ -36,7 +37,7 @@ class RotateOperation;
/// \brief Affine TensorTransform.
/// \notes Apply affine transform on input image.
class Affine : public TensorTransform {
class Affine final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] degrees The degrees to rotate the image by
@ -64,9 +65,10 @@ class Affine : public TensorTransform {
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief CenterCrop TensorTransform.
/// \notes Crops the input image at the center to the given size.
class CenterCrop : public TensorTransform {
class CenterCrop final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] size A vector representing the output size of the cropped image.
@ -77,6 +79,7 @@ class CenterCrop : public TensorTransform {
/// \brief Destructor.
~CenterCrop() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -106,7 +109,7 @@ class RGB2GRAY : public TensorTransform {
/// \brief Crop TensorTransform.
/// \notes Crop an image based on location and crop size
class Crop : public TensorTransform {
class Crop final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] coordinates Starting location of crop. Must be a vector of two values, in the form of {x_coor, y_coor}
@ -118,6 +121,7 @@ class Crop : public TensorTransform {
/// \brief Destructor.
~Crop() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -129,7 +133,7 @@ class Crop : public TensorTransform {
/// \brief Decode TensorTransform.
/// \notes Decode the input image in RGB mode.
class Decode : public TensorTransform {
class Decode final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] rgb A boolean of whether to decode in RGB mode or not.
@ -138,6 +142,7 @@ class Decode : public TensorTransform {
/// \brief Destructor.
~Decode() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -151,7 +156,7 @@ class Decode : public TensorTransform {
/// \brief Normalize TensorTransform.
/// \notes Normalize the input image with respect to mean and standard deviation.
class Normalize : public TensorTransform {
class Normalize final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] mean A vector of mean values for each channel, w.r.t channel order.
@ -163,16 +168,21 @@ class Normalize : public TensorTransform {
/// \brief Destructor.
~Normalize() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
class RandomAffine : public TensorTransform {
/// \brief RandomAffine TensorTransform.
/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode.
class RandomAffine final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] degrees A float vector of size 2, representing the starting and ending degree
@ -210,7 +220,7 @@ class RandomAffine : public TensorTransform {
/// \brief Resize TensorTransform.
/// \notes Resize the input image to the given size.
class Resize : public TensorTransform {
class Resize final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] size A vector representing the output size of the resized image.
@ -222,6 +232,7 @@ class Resize : public TensorTransform {
/// \brief Destructor.
~Resize() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
@ -235,7 +246,7 @@ class Resize : public TensorTransform {
/// \brief Rotate TensorTransform.
/// \notes Rotate the input image using a specified angle id.
class Rotate : public TensorTransform {
class Rotate final : public TensorTransform {
public:
/// \brief Constructor.
Rotate();
@ -243,6 +254,7 @@ class Rotate : public TensorTransform {
/// \brief Destructor.
~Rotate() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;

View File

@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.14.1)
project(testlenet)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -fPIC -std=c++17")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")

View File

@ -33,13 +33,13 @@
using mindspore::dataset::Dataset;
using mindspore::dataset::Iterator;
using mindspore::dataset::Mnist;
using mindspore::dataset::TensorOperation;
using mindspore::dataset::TensorTransform;
int main(int argc, char **argv) {
std::string folder_path = "./testMnistData/";
std::shared_ptr<Dataset> ds = Mnist(folder_path, "all");
std::shared_ptr<TensorOperation> resize = mindspore::dataset::vision::Resize({32, 32});
std::shared_ptr<TensorTransform> resize(new mindspore::dataset::vision::Resize({32, 32}));
ds = ds->Map({resize});
ds = ds->Shuffle(2);