forked from mindspore-Ecosystem/mindspore
!13378 【MD】【Task】add data help for MD lite
From: @xulei2020 Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
888b2e19ee
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONSTANTS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONSTANTS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
@ -27,7 +27,7 @@ using uchar = unsigned char;
|
|||
using dsize_t = int64_t;
|
||||
|
||||
// Target devices to perform map operation
|
||||
enum class MapTargetDevice { kCpu, kGpu, kDvpp };
|
||||
enum class MapTargetDevice { kCpu, kGpu, kAscend310 };
|
||||
|
||||
// Possible dataset types for holding the data and client type
|
||||
enum class DatasetType { kUnknown, kArrow, kTf };
|
||||
|
@ -71,6 +71,9 @@ enum class NormalizeForm {
|
|||
kNfkd,
|
||||
};
|
||||
|
||||
// Possible values for SamplingStrategy
|
||||
enum class SamplingStrategy { kRandom = 0, kEdgeWeight = 1 };
|
||||
|
||||
// convenience functions for 32bit int bitmask
|
||||
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }
|
||||
|
||||
|
@ -84,7 +87,7 @@ constexpr int64_t kDeMaxFreq = std::numeric_limits<int64_t>::max(); // 92233720
|
|||
constexpr int64_t kDeMaxTopk = std::numeric_limits<int64_t>::max();
|
||||
|
||||
constexpr uint32_t kCfgRowsPerBuffer = 1;
|
||||
constexpr uint32_t kCfgParallelWorkers = 4;
|
||||
constexpr uint32_t kCfgParallelWorkers = 8;
|
||||
constexpr uint32_t kCfgWorkerConnectorSize = 16;
|
||||
constexpr uint32_t kCfgOpConnectorSize = 16;
|
||||
constexpr int32_t kCfgDefaultRankId = -1;
|
||||
|
@ -106,4 +109,4 @@ using row_id_type = int64_t;
|
|||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONSTANTS_H_
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONSTANTS_H_
|
||||
|
|
|
@ -0,0 +1,448 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATA_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATA_HELPER_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \brief Simple class to do data manipulation, contains helper function to update json files in dataset
|
||||
class DataHelper {
|
||||
public:
|
||||
/// \brief constructor
|
||||
DataHelper() {}
|
||||
|
||||
/// \brief Destructor
|
||||
~DataHelper() = default;
|
||||
|
||||
/// \brief Create an Album dataset while taking in a path to a image folder
|
||||
/// Creates the output directory if doesn't exist
|
||||
/// \param[in] in_dir Image folder directory that takes in images
|
||||
/// \param[in] out_dir Directory containing output json files
|
||||
Status CreateAlbum(const std::string &in_dir, const std::string &out_dir) {
|
||||
return CreateAlbumIF(StringToChar(in_dir), StringToChar(out_dir));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of string values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional input for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<std::string> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), VectorStringToChar(value), StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of bool values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<bool> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of int8 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int8_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of uint8 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint8_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of int16 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int16_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of uint16 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint16_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of int32 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int32_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of uint32 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint32_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of int64 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<int64_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of uint64 values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<uint64_t> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of float values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<float> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a vector of double values
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value array to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateArray(const std::string &in_file, const std::string &key, const std::vector<double> &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateArrayIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a string value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const std::string &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), StringToChar(value), StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a bool value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const bool &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an int8 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const int8_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an uint8 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const uint8_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an int16 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const int16_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an uint16 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const uint16_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an int32 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const int32_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an uint32 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const uint32_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an int64 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const int64_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with an uint64 value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const uint64_t &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a float value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const float &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Update a json file field with a double value
|
||||
/// \param in_file The input file name to read in
|
||||
/// \param key Key of field to write to
|
||||
/// \param value Value to write to file
|
||||
/// \param out_file Optional parameter for output file path, will write to input file if not specified
|
||||
/// \return Status The status code returned
|
||||
Status UpdateValue(const std::string &in_file, const std::string &key, const double &value,
|
||||
const std::string &out_file = "") {
|
||||
return UpdateValueIF(StringToChar(in_file), StringToChar(key), value, StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief Template function to write tensor to file
|
||||
/// \param[in] in_file File to write to
|
||||
/// \param[in] data Array of type T values
|
||||
/// \return Status The status code returned
|
||||
template <typename T>
|
||||
Status WriteBinFile(const std::string &in_file, const std::vector<T> &data) {
|
||||
try {
|
||||
std::ofstream o(in_file, std::ios::binary | std::ios::out);
|
||||
if (!o.is_open()) {
|
||||
return Status(kMDUnexpectedError, "Error opening Bin file to write");
|
||||
}
|
||||
size_t length = data.size();
|
||||
o.write(reinterpret_cast<const char *>(&data[0]), std::streamsize(length * sizeof(T)));
|
||||
o.close();
|
||||
}
|
||||
// Catch any exception and convert to Status return code
|
||||
catch (const std::exception &err) {
|
||||
return Status(kMDUnexpectedError, "Write bin file failed ");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Write pointer to bin, use pointer to avoid memcpy
|
||||
/// \param[in] in_file File name to write to
|
||||
/// \param[in] data Pointer to data
|
||||
/// \param[in] length Length of values to write from pointer
|
||||
/// \return Status The status code returned
|
||||
template <typename T>
|
||||
Status WriteBinFile(const std::string &in_file, T *data, size_t length) {
|
||||
try {
|
||||
std::ofstream o(in_file, std::ios::binary | std::ios::out);
|
||||
if (!o.is_open()) {
|
||||
return Status(kMDUnexpectedError, "Error opening Bin file to write");
|
||||
}
|
||||
o.write(reinterpret_cast<const char *>(data), std::streamsize(length * sizeof(T)));
|
||||
o.close();
|
||||
}
|
||||
// Catch any exception and convert to Status return code
|
||||
catch (const std::exception &err) {
|
||||
return Status(kMDUnexpectedError, "Write bin file failed ");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Helper function to copy content of a tensor to buffer
|
||||
/// \note This function iterates over the tensor in bytes, since
|
||||
/// \param[in] tensor_addr The memory held by a tensor
|
||||
/// \param[in] tensor_size The amount of data in bytes in tensor_addr, e.g. tensor->SizeInBytes()
|
||||
/// \param[out] addr The address to copy tensor data to
|
||||
/// \param[in] buffer_size The buffer size of addr
|
||||
/// \return The size of the tensor (bytes copied
|
||||
size_t DumpData(const unsigned char *tensor_addr, const size_t &tensor_size, void *addr, const size_t &buffer_size);
|
||||
|
||||
/// \brief Helper function to delete key in json file
|
||||
/// note This function will return okay even if key not found
|
||||
/// \param[in] in_file Json file to remove key from
|
||||
/// \param[in] key The key to remove
|
||||
/// \return Status The status code returned
|
||||
Status RemoveKey(const std::string &in_file, const std::string &key, const std::string &out_file = "") {
|
||||
return RemoveKeyIF(StringToChar(in_file), StringToChar(key), StringToChar(out_file));
|
||||
}
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
/// \param out - The output stream to write output to
|
||||
void Print(std::ostream &out) const;
|
||||
|
||||
/// \brief << Stream output operator overload
|
||||
/// \notes This allows you to write the debug print info using stream operators
|
||||
/// \param out Reference to the output stream being overloaded
|
||||
/// \param ds Reference to the DataSchema to display
|
||||
/// \return The output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, const DataHelper &dh) {
|
||||
dh.Print(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
private:
|
||||
// Helper function for dual ABI support
|
||||
Status CreateAlbumIF(const std::vector<char> &in_dir, const std::vector<char> &out_dir);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<std::vector<char>> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<bool> &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<int8_t> &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<uint8_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<int16_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<uint16_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<int32_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<uint32_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<int64_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key,
|
||||
const std::vector<uint64_t> &value, const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<float> &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateArrayIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<double> &value,
|
||||
const std::vector<char> &out_file);
|
||||
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<char> &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const bool &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int8_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint8_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int16_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint16_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int32_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint32_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const int64_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const uint64_t &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const float &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status UpdateValueIF(const std::vector<char> &in_file, const std::vector<char> &key, const double &value,
|
||||
const std::vector<char> &out_file);
|
||||
Status RemoveKeyIF(const std::vector<char> &in_file, const std::vector<char> &key, const std::vector<char> &out_file);
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATA_HELPER_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <sys/stat.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
@ -30,6 +31,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/iterator.h"
|
||||
#include "include/samplers.h"
|
||||
#include "include/transforms.h"
|
||||
|
@ -39,11 +41,15 @@ namespace dataset {
|
|||
|
||||
class Tensor;
|
||||
class TensorShape;
|
||||
class TreeAdapter;
|
||||
class TreeAdapterLite;
|
||||
class TreeGetters;
|
||||
|
||||
class DatasetCache;
|
||||
class DatasetNode;
|
||||
|
||||
class Iterator;
|
||||
|
||||
class TensorOperation;
|
||||
class SchemaObj;
|
||||
class SamplerObj;
|
||||
|
@ -75,13 +81,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return dataset size. If failed, return -1
|
||||
int64_t GetDatasetSize(bool estimate = false);
|
||||
|
||||
// /// \brief Gets the output type
|
||||
// /// \return a vector of DataType. If failed, return an empty vector
|
||||
// std::vector<DataType> GetOutputTypes();
|
||||
/// \brief Gets the output type
|
||||
/// \return a vector of DataType. If failed, return an empty vector
|
||||
std::vector<mindspore::DataType> GetOutputTypes();
|
||||
|
||||
/// \brief Gets the output shape
|
||||
/// \return a vector of TensorShape. If failed, return an empty vector
|
||||
std::vector<TensorShape> GetOutputShapes();
|
||||
std::vector<std::vector<int64_t>> GetOutputShapes();
|
||||
|
||||
/// \brief Gets the batch size
|
||||
/// \return int64_t
|
||||
|
@ -110,6 +116,11 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
/// \return Shared pointer to the original object
|
||||
std::shared_ptr<Dataset> SetNumWorkers(int32_t num_workers);
|
||||
|
||||
/// \brief Function to create an PullBasedIterator over the Dataset
|
||||
/// \param[in] columns List of columns to be used to specify the order of columns
|
||||
/// \return Shared pointer to the Iterator
|
||||
std::shared_ptr<PullIterator> CreatePullBasedIterator(std::vector<std::vector<char>> columns = {});
|
||||
|
||||
/// \brief Function to create an Iterator over the Dataset pipeline
|
||||
/// \param[in] columns List of columns to be used to specify the order of columns
|
||||
/// \param[in] num_epochs Number of epochs to run through the pipeline, default -1 which means infinite epochs.
|
||||
|
@ -119,6 +130,41 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
return CreateIteratorCharIF(VectorStringToChar(columns), num_epochs);
|
||||
}
|
||||
|
||||
/// \brief Function to transfer data through a device.
|
||||
/// \notes If device is Ascend, features of data will be transferred one by one. The limitation
|
||||
/// of data transmission per time is 256M.
|
||||
/// \param[in] queue_name Channel name (default="", create new unique name).
|
||||
/// \param[in] device_type Type of device (default="", get from MSContext).
|
||||
/// \param[in] device_id id of device (default=1, get from MSContext).
|
||||
/// \param[in] num_epochs Number of epochs (default=-1, infinite epochs).
|
||||
/// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
|
||||
/// \param[in] total_batches Number of batches to be sent to the device (default=0, all data).
|
||||
/// \param[in] create_data_info_queue Whether to create queue which stores types and shapes
|
||||
/// of data or not(default=false).
|
||||
/// \return Returns true if no error encountered else false.
|
||||
bool DeviceQueue(std::string queue_name = "", std::string device_type = "", int32_t device_id = 0,
|
||||
int32_t num_epochs = -1, bool send_epoch_end = true, int32_t total_batches = 0,
|
||||
bool create_data_info_queue = false) {
|
||||
return DeviceQueueCharIF(StringToChar(queue_name), StringToChar(device_type), device_id, num_epochs, send_epoch_end,
|
||||
total_batches, create_data_info_queue);
|
||||
}
|
||||
|
||||
/// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
|
||||
/// \note Usage restrictions:
|
||||
/// 1. Supported dataset formats: 'mindrecord' only
|
||||
/// 2. To save the samples in order, set dataset's shuffle to false and num_files to 1.
|
||||
/// 3. Before calling the function, do not use batch operator, repeat operator or data augmentation operators
|
||||
/// with random attribute in map operator.
|
||||
/// 4. Mindrecord does not support bool, uint64, multi-dimensional uint8(drop dimension) nor
|
||||
/// multi-dimensional string.
|
||||
/// \param[in] file_name Path to dataset file
|
||||
/// \param[in] num_files Number of dataset files (default=1)
|
||||
/// \param[in] file_type Dataset format (default="mindrecord")
|
||||
/// \return Returns true if no error encountered else false
|
||||
bool Save(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord") {
|
||||
return SaveCharIF(StringToChar(dataset_path), num_files, StringToChar(dataset_type));
|
||||
}
|
||||
|
||||
/// \brief Function to create a BatchDataset
|
||||
/// \notes Combines batch_size number of consecutive rows into batches
|
||||
/// \param[in] batch_size The number of rows each batch is created with
|
||||
|
@ -131,8 +177,8 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
|
||||
/// \brief Function to create a MapDataset
|
||||
/// \notes Applies each operation in operations to this dataset
|
||||
/// \param[in] operations Vector of operations to be applied on the dataset. Operations are
|
||||
/// applied in the order they appear in this list
|
||||
/// \param[in] operations Vector of raw pointers to TensorTransform objects to be applied on the dataset. Operations
|
||||
/// are applied in the order they appear in this list
|
||||
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
|
||||
/// operation as input. The size of this list must match the number of
|
||||
/// input columns expected by the first operator. The default input_columns
|
||||
|
@ -160,6 +206,22 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
callbacks);
|
||||
}
|
||||
|
||||
/// \brief Function to create a MapDataset
|
||||
/// \notes Applies each operation in operations to this dataset
|
||||
/// \param[in] operations Vector of shared pointers to TensorTransform objects to be applied on the dataset.
|
||||
/// Operations are applied in the order they appear in this list
|
||||
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
|
||||
/// operation as input. The size of this list must match the number of
|
||||
/// input columns expected by the first operator. The default input_columns
|
||||
/// is the first column
|
||||
/// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
|
||||
/// This parameter is mandatory if len(input_columns) != len(output_columns)
|
||||
/// The size of this list must match the number of output columns of the
|
||||
/// last operation. The default output_columns will have the same
|
||||
/// name as the input columns, i.e., the columns will be replaced
|
||||
/// \param[in] project_columns A list of column names to project
|
||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current MapDataset
|
||||
std::shared_ptr<MapDataset> Map(std::vector<std::shared_ptr<TensorTransform>> operations,
|
||||
const std::vector<std::string> &input_columns = {},
|
||||
const std::vector<std::string> &output_columns = {},
|
||||
|
@ -176,6 +238,22 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
callbacks);
|
||||
}
|
||||
|
||||
/// \brief Function to create a MapDataset
|
||||
/// \notes Applies each operation in operations to this dataset
|
||||
/// \param[in] operations Vector of TensorTransform objects to be applied on the dataset. Operations are applied in
|
||||
/// the order they appear in this list
|
||||
/// \param[in] input_columns Vector of the names of the columns that will be passed to the first
|
||||
/// operation as input. The size of this list must match the number of
|
||||
/// input columns expected by the first operator. The default input_columns
|
||||
/// is the first column
|
||||
/// \param[in] output_columns Vector of names assigned to the columns outputted by the last operation
|
||||
/// This parameter is mandatory if len(input_columns) != len(output_columns)
|
||||
/// The size of this list must match the number of output columns of the
|
||||
/// last operation. The default output_columns will have the same
|
||||
/// name as the input columns, i.e., the columns will be replaced
|
||||
/// \param[in] project_columns A list of column names to project
|
||||
/// \param[in] cache Tensor cache to use. (default=nullptr which means no cache is used).
|
||||
/// \return Shared pointer to the current MapDataset
|
||||
std::shared_ptr<MapDataset> Map(const std::vector<std::reference_wrapper<TensorTransform>> operations,
|
||||
const std::vector<std::string> &input_columns = {},
|
||||
const std::vector<std::string> &output_columns = {},
|
||||
|
@ -221,6 +299,115 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
|
|||
|
||||
// Char interface(CharIF) of CreateIterator
|
||||
std::shared_ptr<Iterator> CreateIteratorCharIF(std::vector<std::vector<char>> columns, int32_t num_epochs);
|
||||
|
||||
// Char interface(CharIF) of DeviceQueue
|
||||
bool DeviceQueueCharIF(const std::vector<char> &queue_name, const std::vector<char> &device_type, int32_t device_id,
|
||||
int32_t num_epochs, bool send_epoch_end, int32_t total_batches, bool create_data_info_queue);
|
||||
|
||||
// Char interface(CharIF) of Save
|
||||
bool SaveCharIF(const std::vector<char> &dataset_path, int32_t num_files, const std::vector<char> &dataset_type);
|
||||
};
|
||||
|
||||
class SchemaObj {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SchemaObj(const std::string &schema_file = "") : SchemaObj(StringToChar(schema_file)) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~SchemaObj() = default;
|
||||
|
||||
/// \brief SchemaObj Init function
|
||||
/// \return bool true if schema initialization is successful
|
||||
Status Init();
|
||||
|
||||
/// \brief Add new column to the schema with unknown shape of rank 1
|
||||
/// \param[in] name Name of the column.
|
||||
/// \param[in] ms_type Data type of the column(mindspore::DataType).
|
||||
/// \return Status code
|
||||
Status add_column(const std::string &name, mindspore::DataType ms_type) {
|
||||
return add_column_char(StringToChar(name), ms_type);
|
||||
}
|
||||
|
||||
/// \brief Add new column to the schema with unknown shape of rank 1
|
||||
/// \param[in] name Name of the column.
|
||||
/// \param[in] ms_type Data type of the column(std::string).
|
||||
/// \param[in] shape Shape of the column.
|
||||
/// \return Status code
|
||||
Status add_column(const std::string &name, const std::string &ms_type) {
|
||||
return add_column_char(StringToChar(name), StringToChar(ms_type));
|
||||
}
|
||||
|
||||
/// \brief Add new column to the schema
|
||||
/// \param[in] name Name of the column.
|
||||
/// \param[in] ms_type Data type of the column(mindspore::DataType).
|
||||
/// \param[in] shape Shape of the column.
|
||||
/// \return Status code
|
||||
Status add_column(const std::string &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape) {
|
||||
return add_column_char(StringToChar(name), ms_type, shape);
|
||||
}
|
||||
|
||||
/// \brief Add new column to the schema
|
||||
/// \param[in] name Name of the column.
|
||||
/// \param[in] ms_type Data type of the column(std::string).
|
||||
/// \param[in] shape Shape of the column.
|
||||
/// \return Status code
|
||||
Status add_column(const std::string &name, const std::string &ms_type, const std::vector<int32_t> &shape) {
|
||||
return add_column_char(StringToChar(name), StringToChar(ms_type), shape);
|
||||
}
|
||||
|
||||
/// \brief Get a JSON string of the schema
|
||||
/// \return JSON string of the schema
|
||||
std::string to_json() { return CharToString(to_json_char()); }
|
||||
|
||||
/// \brief Get a JSON string of the schema
|
||||
std::string to_string() { return to_json(); }
|
||||
|
||||
/// \brief Set a new value to dataset_type
|
||||
void set_dataset_type(std::string dataset_type);
|
||||
|
||||
/// \brief Set a new value to num_rows
|
||||
void set_num_rows(int32_t num_rows);
|
||||
|
||||
/// \brief Get the current num_rows
|
||||
int32_t get_num_rows() const;
|
||||
|
||||
/// \brief Get schema file from JSON file
|
||||
/// \param[in] json_string Name of JSON file to be parsed.
|
||||
/// \return Status code
|
||||
Status FromJSONString(const std::string &json_string) { return FromJSONStringCharIF(StringToChar(json_string)); }
|
||||
|
||||
/// \brief Parse and add column information
|
||||
/// \param[in] json_string Name of JSON string for column dataset attribute information, decoded from schema file.
|
||||
/// \return Status code
|
||||
Status ParseColumnString(const std::string &json_string) {
|
||||
return ParseColumnStringCharIF(StringToChar(json_string));
|
||||
}
|
||||
|
||||
private:
|
||||
// Char constructor of SchemaObj
|
||||
explicit SchemaObj(const std::vector<char> &schema_file);
|
||||
|
||||
// Char interface of add_column
|
||||
Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type);
|
||||
|
||||
Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type);
|
||||
|
||||
Status add_column_char(const std::vector<char> &name, mindspore::DataType ms_type, const std::vector<int32_t> &shape);
|
||||
|
||||
Status add_column_char(const std::vector<char> &name, const std::vector<char> &ms_type,
|
||||
const std::vector<int32_t> &shape);
|
||||
|
||||
// Char interface of to_json
|
||||
const std::vector<char> to_json_char();
|
||||
|
||||
// Char interface of FromJSONString
|
||||
Status FromJSONStringCharIF(const std::vector<char> &json_string);
|
||||
|
||||
// Char interface of ParseColumnString
|
||||
Status ParseColumnStringCharIF(const std::vector<char> &json_string);
|
||||
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
class BatchDataset : public Dataset {
|
||||
|
@ -252,12 +439,17 @@ class ShuffleDataset : public Dataset {
|
|||
|
||||
/// \brief Function to create a SchemaObj
|
||||
/// \param[in] schema_file Path of schema file
|
||||
/// \note This api exists because std::string will constrained by ABI compile macro but char don't.
|
||||
/// \return Shared pointer to the current schema
|
||||
std::shared_ptr<SchemaObj> SchemaCharIF(const std::vector<char> &schema_file);
|
||||
|
||||
/// \brief Function to create a SchemaObj
|
||||
/// \param[in] schema_file Path of schema file
|
||||
/// \return Shared pointer to the current schema
|
||||
inline std::shared_ptr<SchemaObj> Schema(const std::string &schema_file = "") {
|
||||
return SchemaCharIF(StringToChar(schema_file));
|
||||
}
|
||||
|
||||
class AlbumDataset : public Dataset {
|
||||
public:
|
||||
AlbumDataset(const std::vector<char> &dataset_dir, const std::vector<char> &data_schema,
|
||||
|
@ -373,7 +565,6 @@ inline std::shared_ptr<MnistDataset> Mnist(const std::string &dataset_dir, const
|
|||
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
||||
return std::make_shared<MnistDataset>(StringToChar(dataset_dir), StringToChar(usage), sampler, cache);
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -17,15 +17,17 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_EXECUTE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_EXECUTE_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/context.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/constants.h"
|
||||
#include "dataset/include/transforms.h"
|
||||
#include "include/transforms.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DeviceResource;
|
||||
// class to run tensor operations in eager mode
|
||||
class Execute {
|
||||
|
@ -34,7 +36,7 @@ class Execute {
|
|||
// FIXME - Temporarily overload Execute to support both TensorOperation and TensorTransform
|
||||
explicit Execute(std::shared_ptr<TensorOperation> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
||||
explicit Execute(std::shared_ptr<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
||||
// explicit Execute(TensorTransform op, MapTargetDevice deviceType = MapTargetDevice::KCpu);
|
||||
explicit Execute(std::reference_wrapper<TensorTransform> op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
||||
explicit Execute(TensorTransform *op, MapTargetDevice deviceType = MapTargetDevice::kCpu);
|
||||
|
||||
explicit Execute(std::vector<std::shared_ptr<TensorOperation>> ops,
|
||||
|
@ -62,14 +64,23 @@ class Execute {
|
|||
|
||||
Status DeviceMemoryRelease();
|
||||
|
||||
std::string AippCfgGenerator();
|
||||
|
||||
private:
|
||||
Status ParseTransforms_();
|
||||
|
||||
Status validate_device_();
|
||||
|
||||
std::vector<std::shared_ptr<TensorTransform>> transforms_;
|
||||
|
||||
std::vector<std::shared_ptr<TensorOperation>> ops_;
|
||||
|
||||
MapTargetDevice device_type_;
|
||||
|
||||
std::shared_ptr<DeviceResource> device_resource_;
|
||||
|
||||
struct ExtraInfo;
|
||||
std::shared_ptr<ExtraInfo> info_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
|
|
|
@ -37,6 +37,7 @@ class Tensor;
|
|||
|
||||
class NativeRuntimeContext;
|
||||
class IteratorConsumer;
|
||||
class PullBasedIteratorConsumer;
|
||||
|
||||
class Dataset;
|
||||
|
||||
|
@ -80,7 +81,7 @@ class Iterator {
|
|||
/// \note Type of return data is a vector(without column name).
|
||||
/// \param[out] row - the output tensor row.
|
||||
/// \return - a Status error code, returns OK if no error encountered.
|
||||
Status GetNextRow(MSTensorVec *row);
|
||||
virtual Status GetNextRow(MSTensorVec *row);
|
||||
|
||||
/// \brief Function to shut down the data pipeline.
|
||||
void Stop();
|
||||
|
@ -131,6 +132,35 @@ class Iterator {
|
|||
std::unique_ptr<NativeRuntimeContext> runtime_context_;
|
||||
IteratorConsumer *consumer_;
|
||||
};
|
||||
|
||||
class PullIterator : public Iterator {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
PullIterator();
|
||||
|
||||
/// \brief Function to get next row from the data pipeline.
|
||||
/// \note Type of return data is a vector(without column name).
|
||||
/// \param[out] row - the output tensor row.
|
||||
/// \return Returns true if no error encountered else false.
|
||||
Status GetNextRow(MSTensorVec *row) override;
|
||||
|
||||
/// \brief Function to get specified rows from the data pipeline.
|
||||
/// \note Type of return data is a vector(without column name).
|
||||
/// \note This behavior is subject to change
|
||||
/// \param[in] num_rows - the number of rows to fetch.
|
||||
/// \param[out] row - the output tensor row.
|
||||
/// \return Returns true if no error encountered else false.
|
||||
Status GetRows(int32_t num_rows, std::vector<MSTensorVec> *row);
|
||||
|
||||
/// \brief Method for building and launching the pipeline.
|
||||
/// \note Consider making this function protected.
|
||||
/// \param[in] ds - The root node that calls the function
|
||||
/// \return - a Status error code, returns OK if no error encountered.
|
||||
Status BuildAndLaunchTree(std::shared_ptr<Dataset> ds);
|
||||
|
||||
private:
|
||||
std::unique_ptr<PullBasedIteratorConsumer> pull_consumer_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_ITERATOR_H_
|
||||
|
|
|
@ -18,14 +18,12 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Forward declare
|
||||
class SamplerObj;
|
||||
|
||||
// Abstract class to represent a sampler in the data pipeline.
|
||||
|
@ -33,7 +31,20 @@ class SamplerObj;
|
|||
/// \brief An abstract base class to represent a sampler in the data pipeline.
|
||||
class Sampler : std::enable_shared_from_this<Sampler> {
|
||||
friend class AlbumDataset;
|
||||
friend class CelebADataset;
|
||||
friend class Cifar10Dataset;
|
||||
friend class Cifar100Dataset;
|
||||
friend class CLUEDataset;
|
||||
friend class CocoDataset;
|
||||
friend class CSVDataset;
|
||||
friend class ImageFolderDataset;
|
||||
friend class ManifestDataset;
|
||||
friend class MindDataDataset;
|
||||
friend class MnistDataset;
|
||||
friend class RandomDataDataset;
|
||||
friend class TextFileDataset;
|
||||
friend class TFRecordDataset;
|
||||
friend class VOCDataset;
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -57,7 +68,7 @@ class Sampler : std::enable_shared_from_this<Sampler> {
|
|||
|
||||
/// \brief A class to represent a Distributed Sampler in the data pipeline.
|
||||
/// \notes A Sampler that accesses a shard of the dataset.
|
||||
class DistributedSampler : public Sampler {
|
||||
class DistributedSampler final : public Sampler {
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -93,7 +104,7 @@ class DistributedSampler : public Sampler {
|
|||
/// \brief A class to represent a PK Sampler in the data pipeline.
|
||||
/// \notes Samples K elements for each P class in the dataset.
|
||||
/// This will sample all classes.
|
||||
class PKSampler : public Sampler {
|
||||
class PKSampler final : public Sampler {
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -119,7 +130,7 @@ class PKSampler : public Sampler {
|
|||
|
||||
/// \brief A class to represent a Random Sampler in the data pipeline.
|
||||
/// \notes Samples the elements randomly.
|
||||
class RandomSampler : public Sampler {
|
||||
class RandomSampler final : public Sampler {
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -143,7 +154,7 @@ class RandomSampler : public Sampler {
|
|||
|
||||
/// \brief A class to represent a Sequential Sampler in the data pipeline.
|
||||
/// \notes Samples the dataset elements sequentially, same as not having a sampler.
|
||||
class SequentialSampler : public Sampler {
|
||||
class SequentialSampler final : public Sampler {
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -190,7 +201,7 @@ class SubsetSampler : public Sampler {
|
|||
|
||||
/// \brief A class to represent a Subset Random Sampler in the data pipeline.
|
||||
/// \notes Samples the elements randomly from a sequence of indices.
|
||||
class SubsetRandomSampler : public SubsetSampler {
|
||||
class SubsetRandomSampler final : public SubsetSampler {
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -211,7 +222,7 @@ class SubsetRandomSampler : public SubsetSampler {
|
|||
/// \brief A class to represent a Weighted Random Sampler in the data pipeline.
|
||||
/// \notes Samples the elements from [0, len(weights) - 1] randomly with the given
|
||||
/// weights (probabilities).
|
||||
class WeightedRandomSampler : public Sampler {
|
||||
class WeightedRandomSampler final : public Sampler {
|
||||
friend std::shared_ptr<SamplerObj> SelectSampler(int64_t, bool, int32_t, int32_t);
|
||||
|
||||
public:
|
||||
|
@ -234,6 +245,7 @@ class WeightedRandomSampler : public Sampler {
|
|||
int64_t num_samples_;
|
||||
bool replacement_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,9 +17,11 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_TRANSFORMS_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/status.h"
|
||||
#include "include/constants.h"
|
||||
|
@ -29,10 +31,32 @@ namespace dataset {
|
|||
|
||||
class TensorOperation;
|
||||
|
||||
// We need the following two groups of forward declaration to friend the class in class TensorTransform.
|
||||
namespace transforms {
|
||||
class Compose;
|
||||
class RandomApply;
|
||||
class RandomChoice;
|
||||
} // namespace transforms
|
||||
|
||||
namespace vision {
|
||||
class BoundingBoxAugment;
|
||||
class RandomSelectSubpolicy;
|
||||
class UniformAugment;
|
||||
} // namespace vision
|
||||
|
||||
// Abstract class to represent a tensor transform operation in the data pipeline.
|
||||
/// \class TensorTransform transforms.h
|
||||
/// \brief A base class to represent a tensor transform operation in the data pipeline.
|
||||
class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
|
||||
friend class Dataset;
|
||||
friend class Execute;
|
||||
friend class transforms::Compose;
|
||||
friend class transforms::RandomApply;
|
||||
friend class transforms::RandomChoice;
|
||||
friend class vision::BoundingBoxAugment;
|
||||
friend class vision::RandomSelectSubpolicy;
|
||||
friend class vision::UniformAugment;
|
||||
|
||||
public:
|
||||
/// \brief Constructor
|
||||
TensorTransform() {}
|
||||
|
@ -40,6 +64,7 @@ class TensorTransform : public std::enable_shared_from_this<TensorTransform> {
|
|||
/// \brief Destructor
|
||||
~TensorTransform() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Pure virtual function to convert a TensorTransform class into a IR TensorOperation object.
|
||||
/// \return shared pointer to the newly created TensorOperation.
|
||||
virtual std::shared_ptr<TensorOperation> Parse() = 0;
|
||||
|
@ -55,17 +80,22 @@ namespace transforms {
|
|||
|
||||
/// \brief Compose Op.
|
||||
/// \notes Compose a list of transforms into a single transform.
|
||||
class Compose : public TensorTransform {
|
||||
class Compose final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of transformations to be applied.
|
||||
/// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
|
||||
explicit Compose(const std::vector<TensorTransform *> &transforms);
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
|
||||
explicit Compose(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of TensorTransform objects to be applied.
|
||||
explicit Compose(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
|
||||
|
||||
/// \brief Destructor
|
||||
~Compose() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -78,7 +108,7 @@ class Compose : public TensorTransform {
|
|||
/// \brief Duplicate Op.
|
||||
/// \notes Duplicate the input tensor to a new output tensor.
|
||||
/// The input tensor is carried over to the output list.
|
||||
class Duplicate : public TensorTransform {
|
||||
class Duplicate final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
Duplicate();
|
||||
|
@ -86,6 +116,7 @@ class Duplicate : public TensorTransform {
|
|||
/// \brief Destructor
|
||||
~Duplicate() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -93,7 +124,7 @@ class Duplicate : public TensorTransform {
|
|||
|
||||
/// \brief OneHot Op.
|
||||
/// \notes Convert the labels into OneHot format.
|
||||
class OneHot : public TensorTransform {
|
||||
class OneHot final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] num_classes number of classes.
|
||||
|
@ -102,6 +133,7 @@ class OneHot : public TensorTransform {
|
|||
/// \brief Destructor
|
||||
~OneHot() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -113,18 +145,25 @@ class OneHot : public TensorTransform {
|
|||
|
||||
/// \brief RandomApply Op.
|
||||
/// \notes Randomly perform a series of transforms with a given probability.
|
||||
class RandomApply : public TensorTransform {
|
||||
class RandomApply final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of transformations to be applied.
|
||||
/// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
|
||||
/// \param[in] prob The probability to apply the transformation list (default=0.5)
|
||||
explicit RandomApply(const std::vector<TensorTransform *> &transforms, double prob = 0.5);
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
|
||||
/// \param[in] prob The probability to apply the transformation list (default=0.5)
|
||||
explicit RandomApply(const std::vector<std::shared_ptr<TensorTransform>> &transforms, double prob = 0.5);
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of TensorTransform objects to be applied.
|
||||
/// \param[in] prob The probability to apply the transformation list (default=0.5)
|
||||
explicit RandomApply(const std::vector<std::reference_wrapper<TensorTransform>> &transforms, double prob = 0.5);
|
||||
|
||||
/// \brief Destructor
|
||||
~RandomApply() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -136,17 +175,22 @@ class RandomApply : public TensorTransform {
|
|||
|
||||
/// \brief RandomChoice Op.
|
||||
/// \notes Randomly selects one transform from a list of transforms to perform operation.
|
||||
class RandomChoice : public TensorTransform {
|
||||
class RandomChoice final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of transformations to be chosen from to apply.
|
||||
/// \param[in] transforms A vector of raw pointers to TensorTransform objects to be applied.
|
||||
explicit RandomChoice(const std::vector<TensorTransform *> &transforms);
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of shared pointers to TensorTransform objects to be applied.
|
||||
explicit RandomChoice(const std::vector<std::shared_ptr<TensorTransform>> &transforms);
|
||||
/// \brief Constructor.
|
||||
/// \param[in] transforms A vector of TensorTransform objects to be applied.
|
||||
explicit RandomChoice(const std::vector<std::reference_wrapper<TensorTransform>> &transforms);
|
||||
|
||||
/// \brief Destructor
|
||||
~RandomChoice() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -158,7 +202,7 @@ class RandomChoice : public TensorTransform {
|
|||
|
||||
/// \brief TypeCast Op.
|
||||
/// \notes Tensor operation to cast to a given MindSpore data type.
|
||||
class TypeCast : public TensorTransform {
|
||||
class TypeCast final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] data_type mindspore.dtype to be cast to.
|
||||
|
@ -169,6 +213,7 @@ class TypeCast : public TensorTransform {
|
|||
/// \brief Destructor
|
||||
~TypeCast() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -181,7 +226,7 @@ class TypeCast : public TensorTransform {
|
|||
/// \brief Unique Op.
|
||||
/// \notes Return an output tensor containing all the unique elements of the input tensor in
|
||||
/// the same order that they occur in the input tensor.
|
||||
class Unique : public TensorTransform {
|
||||
class Unique final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
Unique();
|
||||
|
@ -189,6 +234,7 @@ class Unique : public TensorTransform {
|
|||
/// \brief Destructor
|
||||
~Unique() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -22,6 +22,7 @@
|
|||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "include/api/status.h"
|
||||
#include "include/constants.h"
|
||||
#include "include/transforms.h"
|
||||
|
||||
|
@ -36,7 +37,7 @@ class RotateOperation;
|
|||
|
||||
/// \brief Affine TensorTransform.
|
||||
/// \notes Apply affine transform on input image.
|
||||
class Affine : public TensorTransform {
|
||||
class Affine final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] degrees The degrees to rotate the image by
|
||||
|
@ -64,9 +65,10 @@ class Affine : public TensorTransform {
|
|||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief CenterCrop TensorTransform.
|
||||
/// \notes Crops the input image at the center to the given size.
|
||||
class CenterCrop : public TensorTransform {
|
||||
class CenterCrop final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] size A vector representing the output size of the cropped image.
|
||||
|
@ -77,6 +79,7 @@ class CenterCrop : public TensorTransform {
|
|||
/// \brief Destructor.
|
||||
~CenterCrop() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -106,7 +109,7 @@ class RGB2GRAY : public TensorTransform {
|
|||
|
||||
/// \brief Crop TensorTransform.
|
||||
/// \notes Crop an image based on location and crop size
|
||||
class Crop : public TensorTransform {
|
||||
class Crop final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] coordinates Starting location of crop. Must be a vector of two values, in the form of {x_coor, y_coor}
|
||||
|
@ -118,6 +121,7 @@ class Crop : public TensorTransform {
|
|||
/// \brief Destructor.
|
||||
~Crop() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -129,7 +133,7 @@ class Crop : public TensorTransform {
|
|||
|
||||
/// \brief Decode TensorTransform.
|
||||
/// \notes Decode the input image in RGB mode.
|
||||
class Decode : public TensorTransform {
|
||||
class Decode final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] rgb A boolean of whether to decode in RGB mode or not.
|
||||
|
@ -138,6 +142,7 @@ class Decode : public TensorTransform {
|
|||
/// \brief Destructor.
|
||||
~Decode() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -151,7 +156,7 @@ class Decode : public TensorTransform {
|
|||
|
||||
/// \brief Normalize TensorTransform.
|
||||
/// \notes Normalize the input image with respect to mean and standard deviation.
|
||||
class Normalize : public TensorTransform {
|
||||
class Normalize final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] mean A vector of mean values for each channel, w.r.t channel order.
|
||||
|
@ -163,16 +168,21 @@ class Normalize : public TensorTransform {
|
|||
/// \brief Destructor.
|
||||
~Normalize() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
std::shared_ptr<TensorOperation> Parse(const MapTargetDevice &env) override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
class RandomAffine : public TensorTransform {
|
||||
/// \brief RandomAffine TensorTransform.
|
||||
/// \notes Applies a Random Affine transformation on input image in RGB or Greyscale mode.
|
||||
class RandomAffine final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] degrees A float vector of size 2, representing the starting and ending degree
|
||||
|
@ -210,7 +220,7 @@ class RandomAffine : public TensorTransform {
|
|||
|
||||
/// \brief Resize TensorTransform.
|
||||
/// \notes Resize the input image to the given size.
|
||||
class Resize : public TensorTransform {
|
||||
class Resize final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] size A vector representing the output size of the resized image.
|
||||
|
@ -222,6 +232,7 @@ class Resize : public TensorTransform {
|
|||
/// \brief Destructor.
|
||||
~Resize() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
@ -235,7 +246,7 @@ class Resize : public TensorTransform {
|
|||
|
||||
/// \brief Rotate TensorTransform.
|
||||
/// \notes Rotate the input image using a specified angle id.
|
||||
class Rotate : public TensorTransform {
|
||||
class Rotate final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
Rotate();
|
||||
|
@ -243,6 +254,7 @@ class Rotate : public TensorTransform {
|
|||
/// \brief Destructor.
|
||||
~Rotate() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
cmake_minimum_required(VERSION 3.14.1)
|
||||
project(testlenet)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -fPIC")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -fPIC -std=c++17")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
|
||||
|
||||
|
|
|
@ -33,13 +33,13 @@
|
|||
using mindspore::dataset::Dataset;
|
||||
using mindspore::dataset::Iterator;
|
||||
using mindspore::dataset::Mnist;
|
||||
using mindspore::dataset::TensorOperation;
|
||||
using mindspore::dataset::TensorTransform;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
std::string folder_path = "./testMnistData/";
|
||||
std::shared_ptr<Dataset> ds = Mnist(folder_path, "all");
|
||||
|
||||
std::shared_ptr<TensorOperation> resize = mindspore::dataset::vision::Resize({32, 32});
|
||||
std::shared_ptr<TensorTransform> resize(new mindspore::dataset::vision::Resize({32, 32}));
|
||||
ds = ds->Map({resize});
|
||||
|
||||
ds = ds->Shuffle(2);
|
||||
|
|
Loading…
Reference in New Issue