forked from mindspore-Ecosystem/mindspore
!22160 dataset: add more parameter checks to avoid security problems
Merge pull request !22160 from ms_yan/sync_fix
This commit is contained in:
commit
cdcfbd2616
|
@ -141,7 +141,7 @@ class MS_API Status {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// api without std::string
|
// api without std::string
|
||||||
explicit Status(enum StatusCode status_code, const std::vector<char> &status_msg);
|
Status(enum StatusCode status_code, const std::vector<char> &status_msg);
|
||||||
Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra);
|
Status(const enum StatusCode code, int line_of_code, const char *file_name, const std::vector<char> &extra);
|
||||||
std::vector<char> ToCString() const;
|
std::vector<char> ToCString() const;
|
||||||
std::vector<char> GetErrDescriptionChar() const;
|
std::vector<char> GetErrDescriptionChar() const;
|
||||||
|
|
|
@ -454,6 +454,10 @@ std::vector<float> AippStdFilter(const std::vector<uint32_t> &normalize_para) {
|
||||||
if (normalize_para.size() == 6) { // If Normalize operator exists
|
if (normalize_para.size() == 6) { // If Normalize operator exists
|
||||||
auto zeros = std::find(std::begin(normalize_para), std::end(normalize_para), 0);
|
auto zeros = std::find(std::begin(normalize_para), std::end(normalize_para), 0);
|
||||||
if (zeros == std::end(normalize_para)) {
|
if (zeros == std::end(normalize_para)) {
|
||||||
|
if (std::any_of(normalize_para.begin() + 3, normalize_para.end(), [](uint32_t i) { return i == 0; })) {
|
||||||
|
MS_LOG(ERROR) << "value in normalize para got 0.";
|
||||||
|
return {};
|
||||||
|
}
|
||||||
std::transform(normalize_para.begin() + 3, normalize_para.end(), std::back_inserter(aipp_std),
|
std::transform(normalize_para.begin() + 3, normalize_para.end(), std::back_inserter(aipp_std),
|
||||||
[](uint32_t i) { return 10000 / static_cast<float>(i); });
|
[](uint32_t i) { return 10000 / static_cast<float>(i); });
|
||||||
} else { // If 0 occurs in std vector
|
} else { // If 0 occurs in std vector
|
||||||
|
|
|
@ -83,7 +83,7 @@ std::unordered_map<int32_t, std::vector<pid_t>> toIntMap(const py::dict input_di
|
||||||
|
|
||||||
std::pair<int64_t, int64_t> toIntPair(const py::tuple tuple) {
|
std::pair<int64_t, int64_t> toIntPair(const py::tuple tuple) {
|
||||||
std::pair<int64_t, int64_t> pair;
|
std::pair<int64_t, int64_t> pair;
|
||||||
if (!tuple.empty()) {
|
if (tuple.size() == 2) {
|
||||||
pair = std::make_pair(toInt64((tuple)[0]), toInt64((tuple)[1]));
|
pair = std::make_pair(toInt64((tuple)[0]), toInt64((tuple)[1]));
|
||||||
}
|
}
|
||||||
return pair;
|
return pair;
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
|
|
||||||
#include "minddata/dataset/engine/opt/pass.h"
|
#include "minddata/dataset/engine/opt/pass.h"
|
||||||
#include "minddata/dataset/util/random.h"
|
#include "minddata/dataset/util/random.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -77,6 +78,8 @@ Status ValidateDatasetDirParam(const std::string &dataset_name, std::string data
|
||||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string real_path;
|
||||||
|
RETURN_IF_NOT_OK(Path::RealPath(dataset_dir, real_path));
|
||||||
Path dir(dataset_dir);
|
Path dir(dataset_dir);
|
||||||
if (!dir.IsDirectory()) {
|
if (!dir.IsDirectory()) {
|
||||||
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
|
std::string err_msg = dataset_name + ": dataset_dir: [" + dataset_dir + "] is an invalid directory path.";
|
||||||
|
|
|
@ -325,6 +325,9 @@ class DataHelper {
|
||||||
return Status(kMDUnexpectedError, "Error opening Bin file to write");
|
return Status(kMDUnexpectedError, "Error opening Bin file to write");
|
||||||
}
|
}
|
||||||
size_t length = data.size();
|
size_t length = data.size();
|
||||||
|
if (length == 0) {
|
||||||
|
return Status(kMDUnexpectedError, "size of data is 0 when written into file.");
|
||||||
|
}
|
||||||
o.write(reinterpret_cast<const char *>(&data[0]), std::streamsize(length * sizeof(T)));
|
o.write(reinterpret_cast<const char *>(&data[0]), std::streamsize(length * sizeof(T)));
|
||||||
o.close();
|
o.close();
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,6 +64,9 @@ class Iterator {
|
||||||
/// \param[out] row The output tensor row.
|
/// \param[out] row The output tensor row.
|
||||||
/// \return Status error code, returns OK if no error encountered.
|
/// \return Status error code, returns OK if no error encountered.
|
||||||
Status GetNextRow(MSTensorMap *row) {
|
Status GetNextRow(MSTensorMap *row) {
|
||||||
|
if (row == nullptr) {
|
||||||
|
return Status(kMDUnexpectedError, "Got nullptr when GetNext row.");
|
||||||
|
}
|
||||||
MSTensorMapChar row_;
|
MSTensorMapChar row_;
|
||||||
row_.clear();
|
row_.clear();
|
||||||
row->clear();
|
row->clear();
|
||||||
|
|
|
@ -82,8 +82,8 @@ class DistributedSampler final : public Sampler {
|
||||||
/// \param[in] offset The starting position where access to elements in the dataset begins (default=-1).
|
/// \param[in] offset The starting position where access to elements in the dataset begins (default=-1).
|
||||||
/// \param[in] even_dist If true, each shard would return the same number of rows (default=true).
|
/// \param[in] even_dist If true, each shard would return the same number of rows (default=true).
|
||||||
/// If false, the total rows returned by all the shards would not overlap.
|
/// If false, the total rows returned by all the shards would not overlap.
|
||||||
explicit DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
|
DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle = true, int64_t num_samples = 0,
|
||||||
uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);
|
uint32_t seed = 1, int64_t offset = -1, bool even_dist = true);
|
||||||
/// \brief Destructor.
|
/// \brief Destructor.
|
||||||
~DistributedSampler() = default;
|
~DistributedSampler() = default;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue