forked from mindspore-Ecosystem/mindspore
Fix test team issues
Signed-off-by: alex-yuyue <yue.yu1@huawei.com>
This commit is contained in:
parent
97b1c5eed5
commit
1850550189
|
@ -29,7 +29,7 @@ std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager();
|
|||
|
||||
// Function to set the seed to be used in any random generator
|
||||
bool set_seed(int32_t seed) {
|
||||
if (seed < 0 || seed > UINT32_MAX) {
|
||||
if (seed < 0 || seed > INT32_MAX) {
|
||||
MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
|
||||
return false;
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ int32_t get_num_parallel_workers() { return _config->num_parallel_workers(); }
|
|||
|
||||
// Function to set the default interval (in milliseconds) for monitor sampling
|
||||
bool set_monitor_sampling_interval(int32_t interval) {
|
||||
if (interval <= 0 || interval > UINT32_MAX) {
|
||||
if (interval <= 0 || interval > INT32_MAX) {
|
||||
MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
|
||||
return false;
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ int32_t get_monitor_sampling_interval() { return _config->monitor_sampling_inter
|
|||
|
||||
// Function to set the default timeout (in seconds) for DSWaitedCallback
|
||||
bool set_callback_timeback(int32_t timeout) {
|
||||
if (timeout <= 0 || timeout > UINT32_MAX) {
|
||||
if (timeout <= 0 || timeout > INT32_MAX) {
|
||||
MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -539,7 +539,7 @@ ZipDataset::ZipDataset(const std::vector<std::shared_ptr<Dataset>> &datasets) {
|
|||
}
|
||||
#endif
|
||||
int64_t Dataset::GetBatchSize() {
|
||||
int64_t batch_size;
|
||||
int64_t batch_size = -1;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
||||
|
@ -548,7 +548,7 @@ int64_t Dataset::GetBatchSize() {
|
|||
}
|
||||
|
||||
int64_t Dataset::GetRepeatCount() {
|
||||
int64_t repeat_count;
|
||||
int64_t repeat_count = 0;
|
||||
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
||||
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
||||
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), 0);
|
||||
|
|
|
@ -227,7 +227,9 @@ Status SaveToDisk::Save() {
|
|||
nlohmann::json row_raw_data;
|
||||
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
|
||||
RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
|
||||
if (row.empty()) break;
|
||||
if (row.empty()) {
|
||||
break;
|
||||
}
|
||||
if (first_loop) {
|
||||
nlohmann::json mr_json;
|
||||
std::vector<std::string> index_fields;
|
||||
|
@ -249,7 +251,7 @@ Status SaveToDisk::Save() {
|
|||
raw_data.insert(
|
||||
std::pair<uint64_t, std::vector<nlohmann::json>>(mr_schema_id, std::vector<nlohmann::json>{row_raw_data}));
|
||||
std::vector<std::vector<uint8_t>> bin_data;
|
||||
if (nullptr != output_bin_data) {
|
||||
if (output_bin_data != nullptr) {
|
||||
bin_data.emplace_back(*output_bin_data);
|
||||
}
|
||||
mr_writer->WriteRawData(raw_data, bin_data);
|
||||
|
|
|
@ -60,7 +60,7 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
|
|||
}
|
||||
|
||||
// Populate the out map from the row and return it
|
||||
for (auto colMap : col_name_id_map_) {
|
||||
for (const auto colMap : col_name_id_map_) {
|
||||
(*out_map)[colMap.first] = std::move(curr_row[colMap.second]);
|
||||
}
|
||||
|
||||
|
@ -197,7 +197,7 @@ Status DatasetIterator::GetOutputShapes(std::vector<TensorShape> *out_shapes) {
|
|||
if (device_queue_row_.empty()) {
|
||||
RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
|
||||
}
|
||||
for (auto ts : device_queue_row_) {
|
||||
for (const auto ts : device_queue_row_) {
|
||||
out_shapes->push_back(ts->shape());
|
||||
}
|
||||
|
||||
|
@ -211,7 +211,7 @@ Status DatasetIterator::GetOutputTypes(std::vector<DataType> *out_types) {
|
|||
if (device_queue_row_.empty()) {
|
||||
RETURN_IF_NOT_OK(FetchNextTensorRow(&device_queue_row_));
|
||||
}
|
||||
for (auto ts : device_queue_row_) {
|
||||
for (const auto ts : device_queue_row_) {
|
||||
out_types->push_back(ts->type());
|
||||
}
|
||||
return Status::OK();
|
||||
|
|
|
@ -81,7 +81,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
|
|||
/// Constructor
|
||||
/// \param op_connector_size - The size for the output connector of this operator.
|
||||
/// \param sampler - The sampler for the op
|
||||
explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);
|
||||
DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);
|
||||
|
||||
/// Destructor
|
||||
virtual ~DatasetOp() { tree_ = nullptr; }
|
||||
|
|
|
@ -26,7 +26,7 @@ namespace dataset {
|
|||
CpuMapJob::CpuMapJob() = default;
|
||||
|
||||
// Constructor
|
||||
CpuMapJob::CpuMapJob(std::vector<std::shared_ptr<TensorOp>> operations) : MapJob(operations) {}
|
||||
CpuMapJob::CpuMapJob(std::vector<std::shared_ptr<TensorOp>> operations) : MapJob(std::move(operations)) {}
|
||||
|
||||
// Destructor
|
||||
CpuMapJob::~CpuMapJob() = default;
|
||||
|
|
|
@ -19,9 +19,9 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/core/tensor_row.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
|
||||
#include "minddata/dataset/callback/callback_param.h"
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/constants.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
|
@ -44,7 +44,7 @@ MapOp::Builder::Builder() {
|
|||
Status MapOp::Builder::sanityCheck() const {
|
||||
if (build_tensor_funcs_.empty()) {
|
||||
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
|
||||
"Building a MapOp that has not provided any function/operation to apply");
|
||||
"Building a MapOp without providing any function/operation to apply");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -121,26 +121,13 @@ Status MapOp::GenerateWorkerJob(const std::unique_ptr<MapWorkerJob> *worker_job)
|
|||
// In the future, we will have heuristic or control from user to select target device
|
||||
MapTargetDevice target_device = MapTargetDevice::kCpu;
|
||||
|
||||
switch (target_device) {
|
||||
case MapTargetDevice::kCpu:
|
||||
// If there is no existing map_job, we will create one.
|
||||
// map_job could be nullptr when we are at the first tensor op or when the target device of the prev op
|
||||
// is different with that of the current op.
|
||||
if (map_job == nullptr) {
|
||||
map_job = std::make_shared<CpuMapJob>();
|
||||
}
|
||||
map_job->AddOperation(tfuncs_[i]);
|
||||
break;
|
||||
|
||||
case MapTargetDevice::kGpu:
|
||||
break;
|
||||
|
||||
case MapTargetDevice::kDvpp:
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
// If there is no existing map_job, we will create one.
|
||||
// map_job could be nullptr when we are at the first tensor op or when the target device of the prev op
|
||||
// is different with that of the current op.
|
||||
if (map_job == nullptr) {
|
||||
map_job = std::make_shared<CpuMapJob>();
|
||||
}
|
||||
map_job->AddOperation(tfuncs_[i]);
|
||||
|
||||
// Push map_job into worker_job if one of the two conditions is true:
|
||||
// 1) It is the last tensor operation in tfuncs_
|
||||
|
@ -364,7 +351,7 @@ Status MapOp::ComputeColMap() {
|
|||
// Validating if each of the input_columns exists in the DataBuffer.
|
||||
Status MapOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map) {
|
||||
for (const auto &inCol : in_columns_) {
|
||||
bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false;
|
||||
bool found = col_name_id_map.find(inCol) != col_name_id_map.end();
|
||||
if (!found) {
|
||||
std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
AlbumOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr), builder_schema_file_("") {
|
||||
AlbumOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
|
||||
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
|
||||
builder_num_workers_ = cfg->num_parallel_workers();
|
||||
builder_rows_per_buffer_ = cfg->rows_per_buffer();
|
||||
|
@ -62,9 +62,8 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) {
|
|||
Status AlbumOp::Builder::SanityCheck() {
|
||||
Path dir(builder_dir_);
|
||||
std::string err_msg;
|
||||
err_msg += dir.IsDirectory() == false
|
||||
? "Invalid parameter, Album path is invalid or not set, path: " + builder_dir_ + ".\n"
|
||||
: "";
|
||||
err_msg +=
|
||||
!dir.IsDirectory() ? "Invalid parameter, Album path is invalid or not set, path: " + builder_dir_ + ".\n" : "";
|
||||
err_msg += builder_num_workers_ <= 0 ? "Invalid parameter, num_parallel_workers must be greater than 0, but got " +
|
||||
std::to_string(builder_num_workers_) + ".\n"
|
||||
: "";
|
||||
|
@ -97,8 +96,8 @@ bool StrComp(const std::string &a, const std::string &b) {
|
|||
// returns 1 if string "a" represent a numeric value less than string "b"
|
||||
// the following will always return name, provided there is only one "." character in name
|
||||
// "." character is guaranteed to exist since the extension is checked befor this function call.
|
||||
int64_t value_a = std::atoi(a.substr(1, a.find(".")).c_str());
|
||||
int64_t value_b = std::atoi(b.substr(1, b.find(".")).c_str());
|
||||
int64_t value_a = std::stoi(a.substr(1, a.find(".")).c_str());
|
||||
int64_t value_b = std::stoi(b.substr(1, b.find(".")).c_str());
|
||||
return value_a < value_b;
|
||||
}
|
||||
|
||||
|
@ -261,6 +260,7 @@ Status AlbumOp::LoadImageTensor(const std::string &image_file_path, uint32_t col
|
|||
RETURN_IF_NOT_OK(LoadEmptyTensor(col_num, row));
|
||||
return Status::OK();
|
||||
}
|
||||
fs.close();
|
||||
// Hack logic to replace png images with empty tensor
|
||||
Path file(image_file_path);
|
||||
std::set<std::string> png_ext = {".png", ".PNG"};
|
||||
|
@ -387,7 +387,7 @@ Status AlbumOp::LoadIDTensor(const std::string &file, uint32_t col_num, TensorRo
|
|||
return Status::OK();
|
||||
}
|
||||
// hack to get the file name without extension, the 1 is to get rid of the backslash character
|
||||
int64_t image_id = std::atoi(file.substr(1, file.find(".")).c_str());
|
||||
int64_t image_id = std::stoi(file.substr(1, file.find(".")).c_str());
|
||||
TensorPtr id;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateScalar<int64_t>(image_id, &id));
|
||||
MS_LOG(INFO) << "File ID " << image_id << ".";
|
||||
|
|
|
@ -16,16 +16,16 @@
|
|||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_ALBUM_OP_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <deque>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/data_buffer.h"
|
||||
#include "minddata/dataset/engine/data_schema.h"
|
||||
|
|
|
@ -55,7 +55,7 @@ Status DistributedSamplerRT::InitSampler() {
|
|||
if (offset_ != -1 || !even_dist_) {
|
||||
if (offset_ == -1) offset_ = 0;
|
||||
samples_per_buffer_ = (num_rows_ + offset_) / num_devices_;
|
||||
int remainder = (num_rows_ + offset_) % num_devices_;
|
||||
int64_t remainder = (num_rows_ + offset_) % num_devices_;
|
||||
if (device_id_ < remainder) samples_per_buffer_++;
|
||||
if (device_id_ < offset_) samples_per_buffer_--;
|
||||
} else {
|
||||
|
@ -63,7 +63,7 @@ Status DistributedSamplerRT::InitSampler() {
|
|||
samples_per_buffer_ = (num_rows_ + num_devices_ - 1) / num_devices_; // equals to ceil(num_rows/num_devices)
|
||||
}
|
||||
samples_per_buffer_ = num_samples_ < samples_per_buffer_ ? num_samples_ : samples_per_buffer_;
|
||||
if (shuffle_ == true) {
|
||||
if (shuffle_) {
|
||||
shuffle_vec_.reserve(num_rows_);
|
||||
for (int64_t i = 0; i < num_rows_; i++) {
|
||||
shuffle_vec_.push_back(i);
|
||||
|
|
|
@ -30,7 +30,7 @@ PKSamplerRT::PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle, int64_t
|
|||
Status PKSamplerRT::InitSampler() {
|
||||
labels_.reserve(label_to_ids_.size());
|
||||
for (const auto &pair : label_to_ids_) {
|
||||
if (pair.second.empty() == false) {
|
||||
if (!pair.second.empty()) {
|
||||
labels_.push_back(pair.first);
|
||||
}
|
||||
}
|
||||
|
@ -76,6 +76,7 @@ Status PKSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
|||
int64_t last_id = (samples_per_buffer_ + next_id_ > num_samples_) ? num_samples_ : samples_per_buffer_ + next_id_;
|
||||
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ids, last_id - next_id_));
|
||||
auto id_ptr = sample_ids->begin<int64_t>();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_class_ != 0, "samples cannot be zero.");
|
||||
while (next_id_ < last_id && id_ptr != sample_ids->end<int64_t>()) {
|
||||
int64_t cls_id = next_id_++ / samples_per_class_;
|
||||
const std::vector<int64_t> &samples = label_to_ids_[labels_[cls_id]];
|
||||
|
|
|
@ -32,8 +32,8 @@ class PKSamplerRT : public SamplerRT { // NOT YET FINISHED
|
|||
// @param int64_t val
|
||||
// @param bool shuffle - shuffle all classIds or not, if true, classes may be 5,1,4,3,2
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
PKSamplerRT(int64_t num_samples, int64_t val, bool shuffle,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// default destructor
|
||||
~PKSamplerRT() = default;
|
||||
|
|
|
@ -28,8 +28,8 @@ RandomSamplerRT::RandomSamplerRT(int64_t num_samples, bool replacement, bool res
|
|||
seed_(GetSeed()),
|
||||
replacement_(replacement),
|
||||
next_id_(0),
|
||||
reshuffle_each_epoch_(reshuffle_each_epoch),
|
||||
dist(nullptr) {}
|
||||
dist(nullptr),
|
||||
reshuffle_each_epoch_(reshuffle_each_epoch) {}
|
||||
|
||||
Status RandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (next_id_ > num_samples_) {
|
||||
|
@ -81,7 +81,7 @@ Status RandomSamplerRT::InitSampler() {
|
|||
samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_;
|
||||
rnd_.seed(seed_);
|
||||
|
||||
if (replacement_ == false) {
|
||||
if (!replacement_) {
|
||||
shuffled_ids_.reserve(num_rows_);
|
||||
for (int64_t i = 0; i < num_rows_; i++) {
|
||||
shuffled_ids_.push_back(i);
|
||||
|
@ -104,7 +104,7 @@ Status RandomSamplerRT::ResetSampler() {
|
|||
|
||||
rnd_.seed(seed_);
|
||||
|
||||
if (replacement_ == false && reshuffle_each_epoch_) {
|
||||
if (!replacement_ && reshuffle_each_epoch_) {
|
||||
std::shuffle(shuffled_ids_.begin(), shuffled_ids_.end(), rnd_);
|
||||
}
|
||||
|
||||
|
|
|
@ -31,8 +31,8 @@ class RandomSamplerRT : public SamplerRT {
|
|||
// @param bool replacement - put he id back / or not after a sample
|
||||
// @param reshuffle_each_epoch - T/F to reshuffle after epoch
|
||||
// @param int64_t samples_per_buffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
RandomSamplerRT(int64_t num_samples, bool replacement, bool reshuffle_each_epoch,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
~RandomSamplerRT() = default;
|
||||
|
@ -50,7 +50,7 @@ class RandomSamplerRT : public SamplerRT {
|
|||
// @return - The error code return
|
||||
Status ResetSampler() override;
|
||||
|
||||
virtual void Print(std::ostream &out, bool show_all) const;
|
||||
void Print(std::ostream &out, bool show_all) const override;
|
||||
|
||||
private:
|
||||
uint32_t seed_;
|
||||
|
|
|
@ -27,7 +27,7 @@ Status RandomAccessOp::GetNumRowsInDataset(int64_t *num) const {
|
|||
// Here, it is just a getter method to return the value. However, it is invalid if there is
|
||||
// not a value set for this count, so generate a failure if that is the case.
|
||||
if (num == nullptr || num_rows_ == 0) {
|
||||
RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed it's num rows yet.");
|
||||
RETURN_STATUS_UNEXPECTED("RandomAccessOp has not computed its num rows yet.");
|
||||
}
|
||||
(*num) = num_rows_;
|
||||
return Status::OK();
|
||||
|
|
|
@ -60,7 +60,7 @@ class SamplerRT {
|
|||
// @param int64_t num_samples: the user-requested number of samples ids to generate. A value of 0
|
||||
// indicates that the sampler should produce the complete set of ids.
|
||||
// @param int64_t samplesPerBuffer: Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit SamplerRT(int64_t num_samples, int64_t samples_per_buffer);
|
||||
SamplerRT(int64_t num_samples, int64_t samples_per_buffer);
|
||||
|
||||
SamplerRT(const SamplerRT &s) : SamplerRT(s.num_samples_, s.samples_per_buffer_) {}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
SequentialSamplerRT::SequentialSamplerRT(int64_t num_samples, int64_t start_index, int64_t samples_per_buffer)
|
||||
: SamplerRT(num_samples, samples_per_buffer), start_index_(start_index), current_id_(start_index), id_count_(0) {}
|
||||
: SamplerRT(num_samples, samples_per_buffer), current_id_(start_index), start_index_(start_index), id_count_(0) {}
|
||||
|
||||
Status SequentialSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
|
||||
if (id_count_ > num_samples_) {
|
||||
|
|
|
@ -30,8 +30,8 @@ class SequentialSamplerRT : public SamplerRT {
|
|||
// full amount of ids from the dataset
|
||||
// @param start_index - The starting index value
|
||||
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
|
||||
explicit SequentialSamplerRT(int64_t num_samples, int64_t start_index,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
SequentialSamplerRT(int64_t num_samples, int64_t start_index,
|
||||
int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
~SequentialSamplerRT() = default;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <random>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
|
|
|
@ -32,8 +32,8 @@ class SubsetRandomSamplerRT : public SamplerRT {
|
|||
// @param indices List of indices from where we will randomly draw samples.
|
||||
// @param samples_per_buffer The number of ids we draw on each call to GetNextBuffer().
|
||||
// When samplesPerBuffer=0, GetNextBuffer() will draw all the sample ids and return them at once.
|
||||
explicit SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
|
||||
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
SubsetRandomSamplerRT(int64_t num_samples, const std::vector<int64_t> &indices,
|
||||
std::int64_t samples_per_buffer = std::numeric_limits<int64_t>::max());
|
||||
|
||||
// Destructor.
|
||||
~SubsetRandomSamplerRT() = default;
|
||||
|
|
|
@ -42,9 +42,10 @@ Status WeightedRandomSamplerRT::InitSampler() {
|
|||
if (num_samples_ == 0 || num_samples_ > num_rows_) {
|
||||
num_samples_ = num_rows_;
|
||||
}
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0 && num_samples_,
|
||||
"Invalid parameter, num_samples & num_rows must be greater than 0, but got num_rows: " +
|
||||
std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
num_rows_ > 0 && num_samples_,
|
||||
"Invalid parameter, num_samples and num_rows must be greater than 0, but got num_rows: " +
|
||||
std::to_string(num_rows_) + ", num_samples: " + std::to_string(num_samples_));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ > 0,
|
||||
"Invalid parameter, samples_per_buffer must be greater than 0, but got " +
|
||||
std::to_string(samples_per_buffer_) + ".\n");
|
||||
|
@ -57,7 +58,7 @@ Status WeightedRandomSamplerRT::InitSampler() {
|
|||
}
|
||||
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, "
|
||||
"Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
|
||||
"but got weight size: " +
|
||||
std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
|
@ -122,7 +123,7 @@ Status WeightedRandomSamplerRT::GetNextSample(std::unique_ptr<DataBuffer> *out_b
|
|||
|
||||
if (!replacement_ && (weights_.size() < static_cast<size_t>(num_samples_))) {
|
||||
RETURN_STATUS_UNEXPECTED(
|
||||
"Invalid parameter, without replacement, weights size must be greater than or equal to num_samples, "
|
||||
"Invalid parameter, without replacement, weight size must be greater than or equal to num_samples, "
|
||||
"but got weight size: " +
|
||||
std::to_string(weights_.size()) + ", num_samples: " + std::to_string(num_samples_));
|
||||
}
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
|
||||
#include <memory>
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/engine/datasetops/dataset_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
|
|
Loading…
Reference in New Issue