!19886 [MD] Fix codecheck and pclint complaints master branch

Merge pull request !19886 from harshvardhangupta/codecheck_master
This commit is contained in:
i-robot 2021-07-15 12:50:28 +00:00 committed by Gitee
commit c72e4920f6
30 changed files with 48 additions and 49 deletions

View File

@ -130,12 +130,13 @@ Status PullIterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
return Status::OK();
}
Iterator::_Iterator::_Iterator(Iterator *lt) : lt_{lt}, cur_row_{nullptr} {
Iterator::_Iterator::_Iterator(Iterator *lt) : lt_{lt}, cur_row_{nullptr}, ind_{0} {
if (lt_) {
cur_row_ = new MSTensorMap();
Status rc = lt_->GetNextRow(cur_row_);
if (rc.IsError()) {
MS_LOG(ERROR) << "Error getting next row. Message: " << rc;
delete cur_row_;
cur_row_ = nullptr;
}
}

View File

@ -23,7 +23,7 @@ namespace dataset {
TensorRow::TensorRow() noexcept : id_(kDefaultRowId), path_({}), tensor_row_flag_(kFlagNone) {}
TensorRow::TensorRow(size_type n, TensorRow::value_type t) noexcept
TensorRow::TensorRow(size_type n, const TensorRow::value_type &t) noexcept
: id_(kDefaultRowId), path_({}), row_(n, t), tensor_row_flag_(kFlagNone) {}
TensorRow::TensorRow(const TensorRow::vector_type &v)

View File

@ -54,7 +54,7 @@ class TensorRow {
TensorRow() noexcept;
TensorRow(size_type n, value_type t) noexcept;
TensorRow(size_type n, const value_type &t) noexcept;
// Copy Constructors
explicit TensorRow(const vector_type &v);

View File

@ -140,8 +140,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
if (tracing_ != nullptr) {
cur_batch_num_++;
RETURN_IF_NOT_OK(tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_,
ProfilingTime::GetCurMilliSecond()));
RETURN_IF_NOT_OK(tracing_->Record(static_cast<int32_t>(CONNECTOR_DEPTH), cur_connector_capacity_, cur_batch_num_,
cur_connector_size_, ProfilingTime::GetCurMilliSecond()));
}
return Status::OK();
}

View File

@ -30,7 +30,7 @@ BarrierOp::BarrierOp(int32_t op_connector_size, const std::string &condition_nam
clean_up_(false),
eof_(false),
condition_name_(condition_name),
condition_function_(condition_func) {}
condition_function_(std::move(condition_func)) {}
// destructor
BarrierOp::~BarrierOp() {}

View File

@ -34,13 +34,13 @@ ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler,
children_start_end_index_ = children_start_end_index;
std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler);
if (distribute_sampler != nullptr) {
num_shard_ = distribute_sampler->GetDeviceNum();
shard_index_ = distribute_sampler->GetDeviceID();
num_shard_ = static_cast<int32_t>(distribute_sampler->GetDeviceNum());
shard_index_ = static_cast<int32_t>(distribute_sampler->GetDeviceID());
}
}
ConcatOp::ConcatOp()
: PipelineOp(0), cur_child_(0), verified_(false), num_shard_(1), shard_index_(0), sample_number_(0) {}
: PipelineOp(0), cur_child_(0), verified_(false), sample_number_(0), num_shard_(1), shard_index_(0) {}
// A function that prints info about the Operator
void ConcatOp::Print(std::ostream &out, bool show_all) const {
@ -124,7 +124,7 @@ bool ConcatOp::IgnoreSample() {
bool is_not_mappable_or_second_ne_zero = true;
if (!children_flag_and_nums_.empty()) {
bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
bool is_not_mappable = static_cast<bool>(children_flag_and_nums_[cur_child_].first);
is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
}
bool ret = true;
@ -151,7 +151,7 @@ Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe
bool is_not_mappable_or_second_ne_zero = true;
if (!children_flag_and_nums_.empty()) {
bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
bool is_not_mappable = static_cast<bool>(children_flag_and_nums_[cur_child_].first);
is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
}
RETURN_IF_NOT_OK(child_[cur_child_]->GetNextRow(row, worker_id, retry_if_eoe));

View File

@ -130,7 +130,7 @@ Status FilterOp::WorkerEntry(int32_t worker_id) {
} else {
RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));
bool result;
bool result = false;
RETURN_IF_NOT_OK(WorkerCompute(new_row, &result));
if (result)

View File

@ -33,7 +33,7 @@ AlbumOp::AlbumOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, boo
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_wkrs, queue_size, std::move(sampler)),
folder_path_(file_dir),
folder_path_(std::move(file_dir)),
decode_(do_decode),
extensions_(exts),
data_schema_(std::move(data_schema)),

View File

@ -182,7 +182,7 @@ class AlbumOp : public MappableLeafOp {
std::vector<std::string> image_rows_;
TensorPtr sample_ids_;
int32_t curr_row_;
uint32_t curr_row_;
};
} // namespace dataset
} // namespace mindspore

View File

@ -17,10 +17,10 @@
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include <fstream>
#include <iomanip>
#include <utility>
#include "debug/common.h"
#include "minddata/dataset/core/config_manager.h"
@ -37,7 +37,7 @@ ClueOp::ClueOp(int32_t num_workers, int64_t num_samples, int32_t worker_connecto
: NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
device_id),
clue_files_list_(std::move(clue_files_list)),
cols_to_keyword_(cols_to_keyword) {}
cols_to_keyword_(std::move(cols_to_keyword)) {}
Status ClueOp::Init() {
RETURN_IF_NOT_OK(filename_index_->insert(clue_files_list_));

View File

@ -58,8 +58,8 @@ CsvOp::CsvParser::CsvParser(int32_t worker_id, JaggedConnector *connector, char
: worker_id_(worker_id),
rows_connector_(connector),
csv_field_delim_(field_delim),
column_default_(column_default),
file_path_(file_path),
column_default_(std::move(column_default)),
file_path_(std::move(file_path)),
cur_state_(START_OF_FILE),
pos_(0),
cur_col_(0),
@ -629,7 +629,7 @@ Status CsvOp::CountAllFileRows(const std::vector<std::string> &files, bool csv_h
std::shared_ptr<CsvOp> op;
*count = 0;
if (!csv_header) {
column_name_list.emplace_back("");
(void)column_name_list.emplace_back("");
}
op = std::make_shared<CsvOp>(files, field_delim, column_list, column_name_list, num_workers, num_samples,
worker_connector_size, op_connector_size, shuffle_files, num_devices, device_id);

View File

@ -28,7 +28,7 @@ GeneratorOp::GeneratorOp(py::function generator_function, std::vector<std::strin
: PipelineOp(connector_size, std::move(sampler)),
generator_function_(generator_function),
column_names_(column_names),
column_types_(column_types),
column_types_(std::move(column_types)),
prefetch_size_(prefetch_size),
generator_counter_(0) {}

View File

@ -29,7 +29,7 @@ ImageFolderOp::ImageFolderOp(int32_t num_wkrs, std::string file_dir, int32_t que
const std::set<std::string> &exts, const std::map<std::string, int32_t> &map,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_wkrs, queue_size, std::move(sampler)),
folder_path_(file_dir),
folder_path_(std::move(file_dir)),
recursive_(recursive),
decode_(do_decode),
extensions_(exts),

View File

@ -38,7 +38,7 @@ ManifestOp::ManifestOp(int32_t num_works, std::string file, int32_t queue_size,
io_block_pushed_(0),
sampler_ind_(0),
data_schema_(std::move(data_schema)),
file_(file),
file_(std::move(file)),
class_index_(class_index),
decode_(decode),
usage_(usage) {

View File

@ -23,7 +23,6 @@
namespace mindspore {
namespace dataset {
MappableLeafOp::MappableLeafOp(int32_t num_wkrs, int32_t queue_size, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_wkrs, queue_size, std::move(sampler)) {}
@ -117,6 +116,5 @@ Status MappableLeafOp::WorkerEntry(int32_t worker_id) {
}
RETURN_STATUS_UNEXPECTED("[Internal ERROR] Unexpected nullptr received in worker.");
}
} // namespace dataset
} // namespace mindspore

View File

@ -47,7 +47,7 @@ MindRecordOp::MindRecordOp(int32_t num_mind_record_workers, std::vector<std::str
const ShuffleMode shuffle_mode, std::unique_ptr<ShardReader> shard_reader,
std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_mind_record_workers, op_connector_queue_size, std::move(sampler)),
dataset_file_(dataset_file),
dataset_file_(std::move(dataset_file)),
load_dataset_(load_dataset),
columns_to_load_(columns_to_load),
operators_(operators),
@ -197,7 +197,7 @@ Status MindRecordOp::WorkerEntry(int32_t worker_id) {
RETURN_STATUS_UNEXPECTED("Unexpected nullptr received in worker.");
}
Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, int64_t row_id, int32_t worker_id) {
Status MindRecordOp::GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, int32_t worker_id) {
*fetched_row = {};
auto rc = shard_reader_->GetNextById(row_id, worker_id);
auto task_type = rc.first;

View File

@ -118,7 +118,7 @@ class MindRecordOp : public MappableLeafOp {
std::string Name() const override { return "MindRecordOp"; }
private:
Status GetRowFromReader(TensorRow *fetched_row, int64_t row_id, int32_t worker_id);
Status GetRowFromReader(TensorRow *fetched_row, uint64_t row_id, int32_t worker_id);
/// Parses a single cell and puts the data into a tensor
/// @param tensor_row - the tensor row to put the parsed data in
@ -139,7 +139,6 @@ class MindRecordOp : public MappableLeafOp {
std::vector<std::string> columns_to_load_; // Columns to load from dataset
std::vector<std::shared_ptr<ShardOperator>> operators_; // ShardOperators to use
int32_t num_mind_record_workers_; // number of workers to be spawned by ShardReader
int32_t buffers_needed_; // Counter for the buffers that were fetched
std::atomic<int32_t> ended_worker_;
int64_t num_padded_;

View File

@ -32,11 +32,11 @@ const int32_t kMnistLabelFileMagicNumber = 2049;
const int32_t kMnistImageRows = 28;
const int32_t kMnistImageCols = 28;
MnistOp::MnistOp(const std::string &usage, int32_t num_workers, std::string folder_path, int32_t queue_size,
MnistOp::MnistOp(std::string usage, int32_t num_workers, std::string folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: MappableLeafOp(num_workers, queue_size, std::move(sampler)),
usage_(usage),
folder_path_(folder_path),
usage_(std::move(usage)),
folder_path_(std::move(folder_path)),
image_path_({}),
label_path_({}),
data_schema_(std::move(data_schema)) {

View File

@ -51,7 +51,7 @@ class MnistOp : public MappableLeafOp {
// @param int32_t queue_size - connector queue size
// @param std::unique_ptr<DataSchema> data_schema - the schema of the mnist dataset
// @param td::unique_ptr<Sampler> sampler - sampler tells MnistOp what to read
MnistOp(const std::string &usage, int32_t num_workers, std::string folder_path, int32_t queue_size,
MnistOp(std::string usage, int32_t num_workers, std::string folder_path, int32_t queue_size,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
// Destructor.

View File

@ -35,18 +35,17 @@
namespace mindspore {
namespace dataset {
NonMappableLeafOp::NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t total_num_rows,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices,
int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_devices),
filename_index_(std::make_unique<StringIndex>()),
load_io_block_queue_(true),
load_jagged_connector_(true),
total_rows_(total_num_rows),
filename_index_(std::make_unique<StringIndex>()),
finished_reading_dataset_(false),
total_rows_(total_num_rows),
load_io_block_queue_(true),
shuffle_files_(shuffle_files),
num_rows_per_shard_(0),
num_rows_(0) {
@ -286,6 +285,5 @@ Status NonMappableLeafOp::WaitToFillIOBlockQueue() {
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -21,7 +21,9 @@ namespace mindspore {
namespace dataset {
PythonSamplerRT::PythonSamplerRT(int64_t num_samples, py::object py_sampler_instance, int64_t samples_per_tensor)
: SamplerRT(num_samples, samples_per_tensor), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {}
: SamplerRT(num_samples, samples_per_tensor),
need_to_reset_(false),
py_sampler_instance(std::move(py_sampler_instance)) {}
Status PythonSamplerRT::GetNextSample(TensorRow *out) {
if (need_to_reset_) {

View File

@ -33,13 +33,13 @@ namespace dataset {
// Constructor for TransferNode
TransferNode::TransferNode(std::shared_ptr<DatasetNode> child, std::string queue_name, std::string device_type,
int32_t device_id, bool send_epoch_end, int32_t total_batch, bool create_data_info_queue)
: prefetch_size_(GlobalContext::config_manager()->prefetch_size()),
queue_name_(std::move(queue_name)),
: queue_name_(std::move(queue_name)),
device_id_(device_id),
device_type_(std::move(device_type)),
prefetch_size_(GlobalContext::config_manager()->prefetch_size()),
send_epoch_end_(send_epoch_end),
total_batch_(total_batch),
create_data_info_queue_(create_data_info_queue),
device_id_(device_id) {
create_data_info_queue_(create_data_info_queue) {
this->AddChild(child);
}

View File

@ -50,7 +50,8 @@ Status ConnectorThroughput::Sample() {
}
auto prev_out_rows_count = out_row_count_table_[col][out_row_count_table_.size() - 1];
if (dt != 0) {
auto thr = (cur_out_rows_count - prev_out_rows_count) / (1000 * dt);
const int32_t multiplier = 1000;
auto thr = (cur_out_rows_count - prev_out_rows_count) / (multiplier * dt);
throughput_row[col] = thr;
} else {
throughput_row[col] = 0;

View File

@ -35,7 +35,7 @@ Status PythonRuntimeContext::TerminateImpl() {
}
PythonRuntimeContext::~PythonRuntimeContext() {
Status rc = Terminate();
Status rc = PythonRuntimeContext::Terminate();
if (rc.IsError()) MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
{
py::gil_scoped_acquire gil_acquire;

View File

@ -36,7 +36,7 @@ Status NativeRuntimeContext::TerminateImpl() {
}
NativeRuntimeContext::~NativeRuntimeContext() {
Status rc = Terminate();
Status rc = NativeRuntimeContext::Terminate();
if (rc.IsError()) MS_LOG(ERROR) << "Error while terminating the consumer. Message:" << rc;
}

View File

@ -38,7 +38,7 @@
namespace mindspore {
namespace dataset {
TreeAdapter::TreeAdapter(UsageFlag usage) : usage_(usage), tree_state_(kCompileStateInit), launched_(false) {
TreeAdapter::TreeAdapter(UsageFlag usage) : usage_(usage), launched_(false), tree_state_(kCompileStateInit) {
optimize_ = common::GetEnv("OPTIMIZE") == "true";
// Initialize profiling parameters

View File

@ -512,7 +512,7 @@ Status MaskHelper(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Te
RETURN_IF_NOT_OK(value_tensor->GetItemAt(&value, {}));
auto in_itr = input->begin<T>();
auto out_itr = output->begin<bool>();
for (; in_itr != input->end<T>(); in_itr++, ++out_itr) {
for (; in_itr != input->end<T>(); ++in_itr, ++out_itr) {
switch (op) {
case RelationalOp::kEqual:
*out_itr = (*in_itr == value);

View File

@ -713,7 +713,6 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
// caller provided 1 mean/std value and there are more than one channel --> duplicate mean/std value
if (mean.size() == 1 && (*output)->shape()[CHANNEL_INDEX] != 1) {
std::vector<float> mean_t, std_t;
for (int64_t i = 0; i < (*output)->shape()[CHANNEL_INDEX] - 1; i++) {
mean.push_back(mean[0]);
std.push_back(std[0]);

View File

@ -32,7 +32,7 @@ Status TokenizerOp::Compute(const TensorRow &input, TensorRow *output) {
}
std::string_view str;
RETURN_IF_NOT_OK(input[0]->GetItemAt(&str, {}));
std::shared_ptr<Tensor> token_tensor, offsets_start_tensor, offsets_limit_tensor;
std::shared_ptr<Tensor> token_tensor;
std::vector<uint32_t> offsets_start, offsets_limit;
std::vector<std::string> splits;
RETURN_IF_NOT_OK(Tokenize(str, &splits, &offsets_start, &offsets_limit));

View File

@ -3982,6 +3982,7 @@ class GeneratorDataset(MappableDataset):
self.source_len = len(self.source)
self.max_rowsize = max_rowsize
self.sample_fn = None
def __deepcopy__(self, memodict):
if id(self) in memodict: