forked from mindspore-Ecosystem/mindspore
!15813 Convert some pipeline Ops to be inlined
From: @hfarahat
commit b86747b9a6
@@ -37,22 +37,26 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
  if (builder_sampler_ == nullptr) {
    builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
  }
  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
                                    children_start_end_index_);
  *ptr = std::make_shared<ConcatOp>(builder_sampler_, children_flag_and_nums_, children_start_end_index_);
  return Status::OK();
}

// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler,
                   const std::vector<std::pair<int, int>> &children_flag_and_nums,
                   const std::vector<std::pair<int, int>> &children_start_end_index)
    : PipelineOp(op_connector_size),
      children_num_(0),
      sampler_(sampler),
      children_flag_and_nums_(children_flag_and_nums),
      children_start_end_index_(children_start_end_index) {}
    : ConcatOp() {
  children_flag_and_nums_ = children_flag_and_nums;
  children_start_end_index_ = children_start_end_index;
  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler);
  if (distribute_sampler != nullptr) {
    num_shard_ = distribute_sampler->GetDeviceNum();
    shard_index_ = distribute_sampler->GetDeviceID();
  }
}

ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
ConcatOp::ConcatOp()
    : PipelineOp(0), cur_child_(0), verified_(false), num_shard_(1), shard_index_(0), sample_number_(0) {}

// A function that prints info about the Operator
void ConcatOp::Print(std::ostream &out, bool show_all) const {

@@ -65,98 +69,16 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
    out << "\nDatasets: " << child_.size() << "\n\n";
  }
}

// This definition is added to pass the cyclomatic complexity rule of <= 20 units
// The NOLINT directive is to disable cpplint check.
// Clang format and cpplint give conflicting recommendations on this line below.
#define f(fv, sv, shard_index)                                                                     \
  (((fv) == -1 && (sv) == -1) || ((fv) < (sv) && (shard_index) >= (fv) && (shard_index) < (sv)) || \
   ((fv) > (sv) && ((shard_index) >= (fv) || (shard_index) < (sv))))  // NOLINT

// Main entry point for Concat
Status ConcatOp::operator()() {
  TaskManager::FindMe()->Post();
  children_num_ = static_cast<int32_t>(child_.size());
  for (int32_t i = 0; i < children_num_; i++) {
    children_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  }
  TensorRow new_row;
  int eof_count = 0;
  int sample_number = 0;
  bool is_not_mappable = true;
  bool is_not_mappable_or_second_ne_zero = true;
  int num_shard = 1;
  int shard_index = 0;
  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
  if (distribute_sampler != nullptr) {
    num_shard = distribute_sampler->GetDeviceNum();
    shard_index = distribute_sampler->GetDeviceID();
  }

  while (eof_count == 0) {
    for (int i = 0; i < children_num_; i++) {
      // 1. Read the first row
      RETURN_IF_NOT_OK(children_iterators_[i]->FetchNextTensorRow(&new_row));
      if (new_row.eof()) {
        eof_count++;
        continue;
      }
      // 2. Do verification as for column name, column data type and rank of column data
      if (!new_row.eoe()) {
        RETURN_IF_NOT_OK(Verify(i, new_row));
      }
      // 3. Put the data into output_connector
      if (!children_flag_and_nums_.empty()) {
        is_not_mappable = children_flag_and_nums_[i].first;
        is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[i].second);
      }
      while (!new_row.eoe() && !new_row.eof()) {
        // if dataset is not mappable, or is a generator dataset whose source is yield (cannot get the number of
        // samples in the python layer), we use filtering to get data
        if (sample_number % num_shard == shard_index && is_not_mappable_or_second_ne_zero) {
          RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
        } else if (!is_not_mappable_or_second_ne_zero) {
          // if dataset is mappable, or is a generator dataset whose source is not yield,
          // get the start and end subscripts of valid values
          int fv = children_start_end_index_[i].first, sv = children_start_end_index_[i].second;

          // determine whether the data allocated to the current shard id is false data
          if (f(fv, sv, shard_index)) {
            RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
          }
        }

        // if dataset is not mappable, or is a generator dataset whose source is yield, sample_number += 1
        if (is_not_mappable_or_second_ne_zero) {
          sample_number++;
        }

        RETURN_IF_NOT_OK(children_iterators_[i]->FetchNextTensorRow(&new_row));
      }

      // if dataset is mappable, we don't use filtering to pick data,
      // so sample_number is increased by the length of the entire dataset
      if (!is_not_mappable_or_second_ne_zero) {
        sample_number += children_flag_and_nums_[i].second;
      }
    }

    // 4. Add eoe row after getting rows from all children
    if (eof_count == 0) {
      RETURN_IF_NOT_OK(out_connector_->SendEOE());
    }
    UpdateRepeatAndEpochCounter();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
                               "Something went wrong, eof count does not match the number of children.");
  // 5. Add eof row in the end manually
  MS_LOG(DEBUG) << "Add the eof row manually in the end.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}
#define f(fv, sv, shard_index)                                                     \
  ((fv == -1 && sv == -1) || (fv < sv && shard_index >= fv && shard_index < sv) || \
   (fv > sv && (shard_index >= fv || shard_index < sv)))  // NOLINT

Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
  if (id == 0) {

@@ -174,6 +96,7 @@ Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
      }
    }
  }
  verified_ = true;
  return Status::OK();
}

@@ -211,6 +134,101 @@ Status ConcatOp::GetNumClasses(int64_t *num_classes) {
  *num_classes = max_num_classes;
  return Status::OK();
}
Status ConcatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ConcatOp is an inlined operator."); }

bool ConcatOp::IgnoreSample() {
  bool is_not_mappable_or_second_ne_zero = true;

  if (!children_flag_and_nums_.empty()) {
    bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
    is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
  }
  bool ret = true;
  if (sample_number_ % num_shard_ == shard_index_ && is_not_mappable_or_second_ne_zero) {
    ret = false;
  } else if (!is_not_mappable_or_second_ne_zero) {
    // if dataset is mappable, or is a generator dataset whose source is not yield,
    // get the start and end subscripts of valid values
    int fv = children_start_end_index_[cur_child_].first, sv = children_start_end_index_[cur_child_].second;

    // determine whether the data allocated to the current shard id is false data
    if (f(fv, sv, shard_index_)) {
      ret = false;
    }
  }

  if (is_not_mappable_or_second_ne_zero) {
    sample_number_++;
  }
  return ret;
}

Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  bool is_not_mappable_or_second_ne_zero = true;

  if (!children_flag_and_nums_.empty()) {
    bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
    is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
  }
  RETURN_IF_NOT_OK(child_[cur_child_]->GetNextRow(row, worker_id, retry_if_eoe));

  if (!row->eoe() && !row->eof()) {
    if (!verified_) RETURN_IF_NOT_OK(Verify(cur_child_, *row));

    if (IgnoreSample()) {
      RETURN_IF_NOT_OK(GetNextRow(row, worker_id, retry_if_eoe));
    }

    return Status::OK();
  }
  if (row->eoe()) {
    // if last child, send out eoe and reset epoch
    if (cur_child_ == child_.size() - 1) {
      // reset
      cur_child_ = 0;
      verified_ = false;
      UpdateRepeatAndEpochCounter();
      return Status::OK();
    }
    if (!is_not_mappable_or_second_ne_zero) {
      sample_number_ += children_flag_and_nums_[cur_child_].second;
    }
    cur_child_++;
    verified_ = false;
    RETURN_IF_NOT_OK(GetNextRow(row, worker_id, retry_if_eoe));
    return Status::OK();
  }
  if (row->eof()) {
    CHECK_FAIL_RETURN_UNEXPECTED(cur_child_ == 0, "Received an unexpected EOF.");
    for (int32_t i = cur_child_ + 1; i < child_.size(); i++) {
      RETURN_IF_NOT_OK(child_[i]->GetNextRow(row, worker_id, retry_if_eoe));
      CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "Row must be an EOF.");
    }
    return Status::OK();
  }

  return Status::OK();
}

int32_t ConcatOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t ConcatOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
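
Note: the sharding arithmetic above appears in two places. The f(fv, sv, shard_index) macro keeps a row when the child's valid range [fv, sv) contains the shard index, where the range may wrap around past the last shard; everything else falls back to the sample_number_ % num_shard_ == shard_index_ filter. A minimal standalone sketch of the macro's predicate follows (the function name InShardRange is hypothetical; the real code keeps this as a macro to satisfy the cyclomatic-complexity rule):

#include <iostream>

// Mirrors the f(fv, sv, shard_index) macro: keep a row when the child placed no
// restriction (fv == sv == -1), or when shard_index lies in the plain interval
// [fv, sv), or in the wrapped interval when fv > sv.
bool InShardRange(int fv, int sv, int shard_index) {
  return (fv == -1 && sv == -1) ||
         (fv < sv && shard_index >= fv && shard_index < sv) ||
         (fv > sv && (shard_index >= fv || shard_index < sv));
}

int main() {
  std::cout << InShardRange(-1, -1, 3) << '\n';  // 1: unrestricted child
  std::cout << InShardRange(1, 3, 2) << '\n';    // 1: 2 lies in [1, 3)
  std::cout << InShardRange(6, 2, 7) << '\n';    // 1: [6, 2) wraps around, 7 >= 6
  std::cout << InShardRange(6, 2, 4) << '\n';    // 0: 4 is outside the wrapped range
}

The wrap-around branch (fv > sv) covers children whose assigned index range crosses the last shard and restarts at shard 0.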
@@ -70,9 +70,8 @@ class ConcatOp : public PipelineOp {
  // Constructor of the ConcatOp.
  // @note The builder class should be used to call it
  // @param op_connector_size - connector size
  explicit ConcatOp(int32_t op_connector_size);
  ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
           const std::vector<std::pair<int, int>> &children_flag_and_nums,
  ConcatOp();
  ConcatOp(const std::shared_ptr<SamplerRT> &sampler, const std::vector<std::pair<int, int>> &children_flag_and_nums,
           const std::vector<std::pair<int, int>> &children_start_end_index);

  // Destructor

@@ -111,18 +110,29 @@ class ConcatOp : public PipelineOp {
  /// \return Status - The status code returned
  Status GetNumClasses(int64_t *num_classes) override;

  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

  /// Check if the current sample will be taken or dropped
  /// \return bool
  bool IgnoreSample();

 private:
  Status Verify(int32_t id, const TensorRow &tensor_row);

  int32_t children_num_;  // The num of child of parent node.
  std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between col index and col name
  std::vector<DataType> data_type_;
  std::vector<dsize_t> data_rank_;
  std::shared_ptr<SamplerRT> sampler_;
  std::vector<std::pair<int, int>> children_flag_and_nums_;
  std::vector<std::pair<int, int>> children_start_end_index_;

  std::vector<std::unique_ptr<ChildIterator>> children_iterators_;  // Iterator for fetching.
  int32_t cur_child_;
  bool verified_;
  int64_t sample_number_;

  int32_t num_shard_;
  int32_t shard_index_;
};
}  // namespace dataset
}  // namespace mindspore
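
Taken together, the .cc and header changes above show the shape of the whole commit: the push-based operator()() (a dedicated thread writing rows into out_connector_) becomes an error stub, and the parent now drives the op directly through a GetNextRow override. A minimal sketch of the pull-based pattern under simplified, hypothetical types (Row, PullOp, SourceOp, PassThroughOp are illustrative, not MindSpore classes):

#include <iostream>
#include <memory>
#include <optional>

using Row = std::optional<int>;  // stand-in for TensorRow; nullopt plays the role of EOF

// Pull-based (inlined) operator: no dedicated thread and no connector queue between
// this op and its parent. The parent simply calls GetNextRow(), which recurses into
// the child -- the pattern behind the new GetNextRow overrides in this commit.
struct PullOp {
  virtual ~PullOp() = default;
  virtual Row GetNextRow() = 0;
};

struct SourceOp : public PullOp {
  int next_ = 0, count_;
  explicit SourceOp(int count) : count_(count) {}
  Row GetNextRow() override { return next_ < count_ ? Row(next_++) : std::nullopt; }
};

// A trivial pass-through, analogous to the new RenameOp::GetNextRow: fetch from the
// child, optionally update bookkeeping, and hand the row straight back to the caller.
struct PassThroughOp : public PullOp {
  std::shared_ptr<PullOp> child_;
  explicit PassThroughOp(std::shared_ptr<PullOp> child) : child_(std::move(child)) {}
  Row GetNextRow() override { return child_->GetNextRow(); }
};

int main() {
  PassThroughOp op(std::make_shared<SourceOp>(3));
  while (auto row = op.GetNextRow()) std::cout << *row << ' ';  // prints: 0 1 2
  std::cout << '\n';
}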
@@ -56,21 +56,52 @@ Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {

FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                   std::shared_ptr<TensorOp> predicate_func)
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}

Status FilterOp::operator()() {
    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {
  worker_queues_.Init(num_workers, op_queue_size);
}
Status FilterOp::LaunchThreadsAndInitOp() {
  // The operator class just starts off threads by calling the tree_ function.
  if (tree_ == nullptr) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
  }
  filter_queues_.Init(num_workers_, oc_queue_size_);
  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
  Status rc =
    tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id());
  RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));

  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
  RETURN_IF_NOT_OK(
    tree_->AllTasks()->CreateAsyncTask("FilterCollector", std::bind(&FilterOp::Collector, this), nullptr, id()));

  return Status::OK();
}

Status FilterOp::operator()() {
  // Synchronize with TaskManager.
  Status rc = LaunchThreadsAndInitOp();
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(rc);
  RETURN_IF_NOT_OK(Collector());

  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  int64_t cnt = 0;
  while (child_iterator_->eof_handled() == false) {
    while (new_row.empty() == false) {
      RETURN_IF_NOT_OK(worker_queues_[cnt % num_workers_]->EmplaceBack(new_row));
      cnt++;
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    }

    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOE))));
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  }
  RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOF))));
  // EOF received, send quit signal to all workers
  for (int32_t ind = 0; ind < num_workers_; ind++) {
    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagQuit))));
  }

  return Status::OK();
}

@@ -110,36 +141,30 @@ void FilterOp::Print(std::ostream &out, bool show_all) const {
}

Status FilterOp::WorkerEntry(int32_t worker_id) {
  std::unique_ptr<ChildIterator> child_iterator = std::make_unique<ChildIterator>(this, worker_id, 0);

  // Handshake with TaskManager that thread creation is successful.
  TaskManager::FindMe()->Post();
  bool worker_stop = false;
  while (worker_stop == false) {
  TensorRow new_row;
  RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));

  while (!new_row.quit()) {
    // Getting a TensorRow to work on.
    TensorRow in_row;
    RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&in_row));
    if (new_row.eoe()) {
      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEoe)));
    } else if (new_row.eof()) {
      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEof)));
    } else {
      RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));

      if (in_row.eoe()) {
        RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(in_row, filterCtrl::kFilterEoe)));
        continue;
      } else if (in_row.eof()) {
        RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(in_row, filterCtrl::kFilterEof)));
        worker_stop = true;
        continue;
      bool result;
      RETURN_IF_NOT_OK(WorkerCompute(new_row, &result));

      if (result)
        RETURN_IF_NOT_OK(
          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterFull)));
      else
        RETURN_IF_NOT_OK(
          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterEmpty)));
    }

    RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));

    bool result;
    RETURN_IF_NOT_OK(WorkerCompute(in_row, &result));

    if (result)
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_row), filterCtrl::kFilterFull)));
    else
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_row), filterCtrl::kFilterEmpty)));
    RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
  }
  return Status::OK();
}

@@ -160,15 +185,16 @@ Status FilterOp::WorkerCompute(const TensorRow &in_row, bool *out_predicate) {

// if the filtered TensorRow is written directly to out_connector_,
// the thread fetching data will block in a queue.
// Collector function will reorder the TensorRow in order.
// Collector thread will reorder the TensorRow in order until EOF is received
// for example in two work queues:
// in filter_queues_:
// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
// queue1: TR(data1 kFilterEmpty) TR(eoe) TR(data4) TR(eof)
// queue2: TR(data2) TR(data3 kFilterEmpty) TR(eoe)
// after reorder in out_connector_:
// queue1: DB(data2) DB(data4) DB(eof)
// queue2: DB(eoe) DB(eoe)
// queue1: TR(data2) TR(data4) TR(eof)
// queue2: TR(eoe) TR(eoe)
Status FilterOp::Collector() {
  TaskManager::FindMe()->Post();
  bool collector_stop = false;
  uint64_t task_id_cnt = 0;
  uint64_t out_id_cnt = 0;

@@ -216,6 +242,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate

  return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
}
int32_t FilterOp::num_consumers() const { return 1; }

}  // namespace dataset
}  // namespace mindspore
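
FilterOp stays parallel rather than inlined: the reworked operator()() now acts as a dispatcher, pulling rows from its child through a ChildIterator and dealing them round-robin into worker_queues_ via cnt % num_workers_, with EOE/EOF/Quit control rows appended in the same rotation. A toy single-threaded sketch of that dealing step (plain std::queue, no synchronization; illustrative only):

#include <iostream>
#include <queue>
#include <vector>

int main() {
  const int num_workers = 3;
  std::vector<std::queue<int>> worker_queues(num_workers);

  // Deal rows 0..9 to workers in arrival order. Worker i receives exactly the rows
  // with cnt % num_workers == i, which is what lets the collector restore order later.
  for (int cnt = 0; cnt < 10; ++cnt) worker_queues[cnt % num_workers].push(cnt);

  for (int i = 0; i < num_workers; ++i) {
    std::cout << "worker " << i << ':';
    for (; !worker_queues[i].empty(); worker_queues[i].pop()) {
      std::cout << ' ' << worker_queues[i].front();
    }
    std::cout << '\n';
  }
}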
@@ -126,6 +126,8 @@ class FilterOp : public ParallelOp {
  // @return Name of the current Op
  std::string Name() const override { return kFilterOp; }

  int32_t num_consumers() const override;

 private:
  // predicate_func python callable which returns a boolean value.
  std::shared_ptr<TensorOp> predicate_func_;

@@ -136,6 +138,10 @@ class FilterOp : public ParallelOp {
  // Internal queue for filter.
  QueueList<std::pair<TensorRow, filterCtrl>> filter_queues_;

  QueueList<TensorRow> worker_queues_;  // internal queue for syncing workers

  std::unique_ptr<ChildIterator> child_iterator_;

  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of FilterOp: getting the data from the previous Op, validating the user specified column names,
  // applying the predicate to each row, and filtering a row out when the predicate result is false.

@@ -168,6 +174,10 @@ class FilterOp : public ParallelOp {
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The status code returned
  Status ValidateInColumns(const std::vector<std::string> &input_columns);

  // Do the initialization of all queues then start all worker threads
  // @return Status The status code returned
  Status LaunchThreadsAndInitOp();
};

}  // namespace dataset
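
The comment block in filter_op.cc above describes how the Collector restores row order: results come back through per-worker queues in round-robin task order, and rows flagged kFilterEmpty are dropped. A single-threaded sketch of that reorder-and-drop pass (the Ctrl/Item types are illustrative; only the flag names mirror filterCtrl):

#include <iostream>
#include <queue>
#include <vector>

enum class Ctrl { kFilterFull, kFilterEmpty };
struct Item { int value; Ctrl ctrl; };

int main() {
  const int num_workers = 2;
  std::vector<std::queue<Item>> filter_queues(num_workers);
  // Task ids 0..4 were dealt round-robin, so worker (id % 2) holds task id.
  filter_queues[0].push({0, Ctrl::kFilterEmpty});
  filter_queues[1].push({1, Ctrl::kFilterFull});
  filter_queues[0].push({2, Ctrl::kFilterFull});
  filter_queues[1].push({3, Ctrl::kFilterEmpty});
  filter_queues[0].push({4, Ctrl::kFilterFull});

  for (int task_id = 0; task_id < 5; ++task_id) {  // visit results in original task order
    auto &q = filter_queues[task_id % num_workers];
    Item item = q.front();
    q.pop();
    if (item.ctrl == Ctrl::kFilterFull) std::cout << item.value << ' ';  // keep surviving rows
  }
  std::cout << '\n';  // prints: 1 2 4
}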
@@ -43,42 +43,28 @@ Status RenameOp::Builder::SanityCheck() const { return Status::OK(); }
// build method for RenameOp
Status RenameOp::Builder::Build(std::shared_ptr<RenameOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_, builder_op_connector_size_);
  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_);
  return Status::OK();
}

// constructor
RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
                   int32_t op_connector_size)
    : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {}
RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names)
    : PipelineOp(0), in_columns_(in_col_names), out_columns_(out_col_names) {}

// destructor
RenameOp::~RenameOp() {}

// main entry point for rename
Status RenameOp::operator()() {
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);

  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));

  while (!new_row.eof()) {
    while (!new_row.eoe()) {
      MS_LOG(DEBUG) << "Rename operator pushing next row.";
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    }
    RETURN_IF_NOT_OK(out_connector_->SendEOE());
    MS_LOG(DEBUG) << "Rename operator EOE Received.";
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    MS_LOG(DEBUG) << "Rename operator fetching row after EOE.";
// Gets a row from the child operator and projects the row.
Status RenameOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
  if (row->eoe()) {
    UpdateRepeatAndEpochCounter();
  }
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  MS_LOG(DEBUG) << "Rename operator EOF Received.";
  return Status::OK();
}

Status RenameOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RenameOp is an inlined operator."); }

// Rename core functionality to compute the new column name id map.
// We need to overwrite the super class ComputeColMap here because we're making a modification of the
// map from the child map.

@@ -151,15 +137,25 @@ void RenameOp::Print(std::ostream &out,  // In: The output stream to print to
  }
}

Status RenameOp::EofReceived(int32_t) {
  MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now.";
  return Status::OK();
int32_t RenameOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

Status RenameOp::EoeReceived(int32_t) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
int32_t RenameOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}

}  // namespace dataset
}  // namespace mindspore
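
The num_consumers/num_producers overrides repeated across these ops exist because an inlined op no longer owns a connector of its own: thread-count queries have to pass through it to the nearest op that does. A compact sketch of that delegation (the Node/ParallelLeaf types are illustrative, not MindSpore classes):

#include <iostream>
#include <memory>

struct Node {
  std::shared_ptr<Node> child;
  virtual ~Node() = default;
  // Inlined op: it produces nothing itself, so it defers the answer to its child.
  virtual int num_producers() const { return child ? child->num_producers() : 0; }
};

struct ParallelLeaf : Node {  // a real producer, e.g. a source op with worker threads
  int workers;
  explicit ParallelLeaf(int w) : workers(w) {}
  int num_producers() const override { return workers; }
};

int main() {
  Node rename;  // stands in for the inlined RenameOp
  rename.child = std::make_shared<ParallelLeaf>(4);
  std::cout << rename.num_producers() << '\n';  // 4: answered by the leaf, not the inline op
}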
@@ -80,17 +80,11 @@ class RenameOp : public PipelineOp {
  // @param in_col_names names of columns to rename
  // @param out_col_names names of columns after rename
  // @param op_connector_size connector size
  RenameOp(const std::vector<std::string> &in_col_names,   // In: Col names to consume
           const std::vector<std::string> &out_col_names,  // In: Col names to produce
           int32_t op_connector_size);
  RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names);

  // Destructor
  ~RenameOp();

  Status EofReceived(int32_t) override;

  Status EoeReceived(int32_t) override;

  // Print function for Rename
  // @param out output stream to print to
  // @param show_all if it should print everything

@@ -112,6 +106,13 @@ class RenameOp : public PipelineOp {
  // @return Name of the current Op
  std::string Name() const override { return kRenameOp; }

  // Gets a row from the child node and projects that row. The caller is typically our parent node.
  // @param row - output pointer to the projected row.
  // @param worker_id - The worker id
  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 protected:
  // Rename core functionality
  // Computing the assignment of the new column name map.
@@ -43,13 +43,12 @@ Status SkipOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
  *ptr = std::make_shared<SkipOp>(build_max_skips_);
  return Status::OK();
}

// Constructor of the SkipOp.
SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
    : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}

// Destructor
SkipOp::~SkipOp() {}

@@ -69,34 +68,48 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
  }
}

// main entry point for skip
Status SkipOp::operator()() {
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }

  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  while (!new_row.eof()) {
    // Reset count
    skip_count_ = 0;
    while (!new_row.eoe()) {
      // Drop first count rows
      if (skip_count_ < max_skips_) {
        skip_count_++;
      } else {
        RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
      }
      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  bool eoe_received = false;
  while (skip_count_ < max_skips_) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
    if (row->eoe()) {
      eoe_received = true;
      break;
    }
    // we got eoe, now try again until we got eof
    MS_LOG(DEBUG) << "Skip operator EOE Received.";
    RETURN_IF_NOT_OK(out_connector_->SendEOE());
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
    skip_count_++;
  }
  if (!eoe_received) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
  }
  if (row->eoe()) {
    UpdateRepeatAndEpochCounter();
    skip_count_ = 0;
  }
  MS_LOG(DEBUG) << "Skip operator EOF Received.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}

int32_t SkipOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t SkipOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}

}  // namespace dataset
}  // namespace mindspore
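
The new SkipOp::GetNextRow drops the first max_skips_ rows of each epoch and resets the counter when it sees EOE. A standalone sketch of that contract over a toy row stream (the Row struct here is illustrative, not the real TensorRow):

#include <iostream>
#include <vector>

struct Row { bool eoe; int value; };

int main() {
  const int max_skips = 2;
  int skip_count = 0;
  // two epochs of rows 0..3, each terminated by an EOE marker
  std::vector<Row> stream = {{false, 0}, {false, 1}, {false, 2}, {false, 3}, {true, 0},
                             {false, 0}, {false, 1}, {false, 2}, {false, 3}, {true, 0}};
  for (const Row &row : stream) {
    if (row.eoe) {  // epoch boundary: propagate EOE and reset, as GetNextRow does
      skip_count = 0;
      std::cout << "| ";
      continue;
    }
    if (skip_count < max_skips) {  // drop the first max_skips rows of this epoch
      skip_count++;
      continue;
    }
    std::cout << row.value << ' ';
  }
  std::cout << '\n';  // prints: 2 3 | 2 3 |
}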
@@ -50,7 +50,7 @@ class SkipOp : public PipelineOp {
  // Constructor of the SkipOp.
  // @note The builder class should be used to call it
  // @param count - The number of skips to do
  explicit SkipOp(int32_t count, int32_t op_connector_size);
  explicit SkipOp(int32_t count);

  // Destructor
  ~SkipOp();

@@ -69,6 +69,9 @@ class SkipOp : public PipelineOp {
  // Op name getter
  // @return Name of the current Op
  std::string Name() const override { return kSkipOp; }
  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 private:
  int32_t max_skips_;  // The number of skips that the user requested
@@ -43,13 +43,12 @@ Status TakeOp::Builder::SanityCheck() const {
// The builder "build" method creates the final object.
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
  *ptr = std::make_shared<TakeOp>(build_max_takes_);
  return Status::OK();
}

// Constructor of the TakeOp.
TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
    : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}

// A print method typically used for debugging
void TakeOp::Print(std::ostream &out, bool show_all) const {

@@ -66,37 +65,53 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
  }
}

// Main entry point for Take
Status TakeOp::operator()() {
  TaskManager::FindMe()->Post();
  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }

  TensorRow new_row;
  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));

  while (!new_row.eof()) {
    while (!new_row.eoe()) {
      if (take_count_ < max_takes_) {
        RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
        take_count_++;
        RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
      }
      if (take_count_ == max_takes_) {
        RETURN_IF_NOT_OK(child_iterator_->Drain());
        break;
      }
Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  bool eoe_received = false;
  if (take_count_ < max_takes_) {
    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
    if (row->eoe()) {
      eoe_received = true;
    } else {
      take_count_++;
      return Status::OK();
    }
  }
  if (take_count_ == max_takes_) {
    // drain
    while (!row->eoe()) {
      RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
    }
    eoe_received = true;
  }
  if (eoe_received) {
    UpdateRepeatAndEpochCounter();
    take_count_ = 0;
    RETURN_IF_NOT_OK(out_connector_->SendEOE());
    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
  }

  take_count_ = 0;
  MS_LOG(DEBUG) << "Meet the end and push-back eof row.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}

int32_t TakeOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t TakeOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
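
The new TakeOp::GetNextRow forwards at most max_takes_ rows per epoch, then keeps pulling from the child until EOE so the epoch still terminates cleanly before take_count_ is reset. A sketch of that behavior over a toy stream (same illustrative Row modeling as in the SkipOp sketch above):

#include <iostream>
#include <vector>

struct Row { bool eoe; int value; };

int main() {
  const int max_takes = 2;
  int take_count = 0;
  std::vector<Row> stream = {{false, 0}, {false, 1}, {false, 2}, {false, 3}, {true, 0}};
  for (size_t i = 0; i < stream.size();) {
    if (take_count < max_takes && !stream[i].eoe) {
      std::cout << stream[i].value << ' ';  // forward the row to the parent
      take_count++;
      i++;
    } else {
      while (!stream[i].eoe) i++;  // drain: discard child rows until its EOE
      std::cout << "|\n";          // forward EOE and reset for the next epoch
      take_count = 0;
      i++;
    }
  }
  // prints: 0 1 |
}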
@@ -53,7 +53,7 @@ class TakeOp : public PipelineOp {
  // Constructor of the TakeOp.
  // @note The builder class should be used to call it
  // @param count - The number of takes to do
  explicit TakeOp(int32_t count, int32_t op_connector_size);
  explicit TakeOp(int32_t count);

  // Destructor
  ~TakeOp() = default;

@@ -82,6 +82,10 @@ class TakeOp : public PipelineOp {
  // @return Name of the current Op
  std::string Name() const override { return kTakeOp; }

  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 private:
  int32_t max_takes_;   // The number of takes that the user requested
  int32_t take_count_;  // A counter for the current number of executed takes
@@ -45,106 +45,40 @@ Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
}

// Construct ZipOp here, local variables initialized in operator due to tree construction restrictions
ZipOp::ZipOp(int32_t op_connector_size)
    : PipelineOp(op_connector_size), children_num_(0), draining_(false), eof_(false) {}
ZipOp::ZipOp(int32_t op_connector_size) : PipelineOp(0) {}

// destructor
ZipOp::~ZipOp() {}

// Entry point for Zip, called by launch()
Status ZipOp::operator()() {
  // The children_num_ parameter needs to be put here
  children_num_ = child_.size();
  // Synchronize with TaskManager once the thread is created.
  TaskManager::FindMe()->Post();

  // initialize the iterators
  for (int32_t i = 0; i < children_num_; ++i) {
    // magic number 0 since Zip is not a parallel Op
    child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
  }

  // Loop until eof is true
  while (!eof_) {
    // 1 Prepare new epoch
    RETURN_IF_NOT_OK(prepare());
    // 2 fetch first row
    TensorRow row;
    RETURN_IF_NOT_OK(getNextTensorRow(&row));

    // If an eof got picked up, then we're done
    if (eof_) {
      break;
    }
    while (!draining_) {
      // 3 send new row to the out connector
      MS_LOG(DEBUG) << "Zip operator finished one row, pushing, cols " << row.size() << ", map "
                    << column_name_id_map_.size() << ".";
      RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
      // 4 fetch one more row
      RETURN_IF_NOT_OK(getNextTensorRow(&row));
    }
    // 5 handle drain state.
    if (draining_) {
      MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
      RETURN_IF_NOT_OK(drainPipeline());
      // Now that we have drained child inputs, send the eoe up.
      RETURN_IF_NOT_OK(out_connector_->SendEOE());
    }
  }

  // 6 handle eof
  MS_LOG(DEBUG) << "Zip operator got EOF, propagating.";
  RETURN_IF_NOT_OK(out_connector_->SendEOF());
  return Status::OK();
}

// Handles preprocessing of the main loop, used when starting new epoch
Status ZipOp::prepare() {
  MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
  draining_ = false;
  return Status::OK();
}

// fetches next zipped (merged) row
Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
Status ZipOp::getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, int32_t worker_id,
                               bool retry_if_eoe) {
  *new_zip_row = {};
  // iterate over all iterators and generate a row
  for (int32_t i = 0; i < children_num_; ++i) {
    TensorRow new_row = {};
    RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
    // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
    if (new_row.empty()) {
      // If we did not get a row from any of the children, then it's the end of an epoch and we can move
      // to drain state.
      MS_LOG(DEBUG) << "Zip operator child iterator produced empty row.";
      draining_ = true;
      new_zip_row->clear();
      // If we picked up an eof here, then we are completely done.
      if ((child_iterators_[i])->eof_handled()) {
        MS_LOG(DEBUG) << "Zip operator iterator got EOF.";
        eof_ = true;
      }
  for (int32_t i = 0; i < child_.size(); ++i) {
    TensorRow new_row;
    RETURN_IF_NOT_OK(child_[i]->GetNextRow(&new_row, worker_id, retry_if_eoe));
    if (new_row.eoe() || new_row.eof()) {
      *new_zip_row = new_row;
      *skip_child = i;
      return Status::OK();
    } else {
      MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
      // if row isn't empty then we can append the fetched row with new_zip_row
      new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
    }
  }
  MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
  return Status::OK();
}

// drain end of epoch messages from iterator for this epoch
Status ZipOp::drainPipeline() {
  // we don't need to drain if we reached eof
  if (eof_) {
    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
                  "ZipOp draining should not be done if already at eof!");
  }
  for (int32_t con = 0; con < children_num_; ++con) {
Status ZipOp::drainPipeline(int32_t skip_child, int32_t worker_id, bool retry_if_eoe) {
  for (int32_t con = 0; con < child_.size(); ++con) {
    if (con == skip_child) continue;
    MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
    RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
    TensorRow row;
    while (!row.eoe()) {
      RETURN_IF_NOT_OK(child_[con]->GetNextRow(&row, worker_id, retry_if_eoe));
    }
  }
  // at this point all connectors don't contain end of epoch messages. next iteration should be clean
  return Status::OK();

@@ -161,9 +95,9 @@ void ZipOp::Print(std::ostream &out,  // In: The output stream to print to
  } else {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
  }
  // Then show any custom derived-internal stuff
  out << "\nDatasets: " << child_.size() << "\n\n";
}

// overwrite function and handle eof

@@ -202,5 +136,39 @@ Status ZipOp::ComputeColMap() {
  }
  return Status::OK();
}

Status ZipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ZipOp is an inlined operator."); }

Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
  int32_t skip_child = -1;
  RETURN_IF_NOT_OK(getNextZippedRow(row, &skip_child, worker_id, retry_if_eoe));
  if (row->eoe()) {
    UpdateRepeatAndEpochCounter();
    MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
    RETURN_IF_NOT_OK(drainPipeline(skip_child, worker_id, retry_if_eoe));
  }
  return Status::OK();
}

int32_t ZipOp::num_consumers() const {
  if (parent_.empty()) {
    MS_LOG(DEBUG) << "Return operator, no parent node, assuming it's the root and returning 1.";
    return 1;
  } else if (parent_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to the first parent is null. Returning 0.";
    return 0;
  } else {
    return parent_[0]->num_consumers();
  }
}

int32_t ZipOp::num_producers() const {
  if (child_.empty() || child_[0] == nullptr) {
    MS_LOG(DEBUG) << "Return operator, pointer to child node is null. Returning 0.";
    return 0;
  } else {
    return child_[0]->num_producers();
  }
}
}  // namespace dataset
}  // namespace mindspore
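
getNextZippedRow builds each output row by splicing one row from every child, and bails out (recording skip_child) as soon as any child reports EOE/EOF so that drainPipeline can flush the remaining children. A minimal sketch of the splice itself (rows modeled as std::vector<int> for illustration; the real code splices TensorRows):

#include <iostream>
#include <vector>

using RowSketch = std::vector<int>;

RowSketch ZipRows(const std::vector<RowSketch> &child_rows) {
  RowSketch zipped;
  for (const RowSketch &row : child_rows) {
    zipped.insert(zipped.end(), row.begin(), row.end());  // append this child's columns
  }
  return zipped;
}

int main() {
  RowSketch zipped = ZipRows({{1}, {10, 11}, {100}});
  for (int col : zipped) std::cout << col << ' ';  // prints: 1 10 11 100
  std::cout << '\n';
}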
@@ -104,15 +104,16 @@ class ZipOp : public PipelineOp {
  // @return Name of the current Op
  std::string Name() const override { return kZipOp; }

 private:
  // Handles preprocessing of the main loop, used when starting new epoch
  Status prepare();
  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
  int32_t num_consumers() const override;
  int32_t num_producers() const override;

 private:
  // Special handle case where an empty row has been received from child iterator
  // @note - we need to drain eoe signals from all children connectors.
  // @details - when this function is called, then we encountered eoe at child iterator
  // we have to drain rows from other child iterators until we hit eoe from all other child iterators
  Status drainPipeline();
  Status drainPipeline(int32_t skip_child, int32_t worker_id, bool retry_if_eoe);

  // Merges 1 row from each childIterator together
  // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty

@@ -125,16 +126,11 @@ class ZipOp : public PipelineOp {
  //   1   a   T
  //    \  |  /
  //    1, a, T
  Status getNextTensorRow(TensorRow *const new_zip_row);
  Status getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, int32_t worker_id, bool retry_if_eoe);

  // Computing the assignment of the column name map.
  // @return - Status
  Status ComputeColMap() override;

  int32_t children_num_;
  bool draining_;
  bool eof_;
  std::vector<std::unique_ptr<ChildIterator>> child_iterators_;
};
}  // namespace dataset
}  // namespace mindspore
@@ -123,12 +123,11 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  std::shared_ptr<ConcatOp> op;
  if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
    op = std::make_shared<ConcatOp>(connector_que_size_);
    op = std::make_shared<ConcatOp>();
  } else {
    std::shared_ptr<SamplerRT> sampler_rt = nullptr;
    RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
    op =
      std::make_shared<ConcatOp>(connector_que_size_, sampler_rt, children_flag_and_nums_, children_start_end_index_);
    op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
  }
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
@@ -58,7 +58,7 @@ Status RenameNode::ValidateParams() {
}

Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_);
  auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
@@ -39,7 +39,7 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" +

// Function to build the SkipOp
Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op = std::make_shared<SkipOp>(skip_count_, connector_que_size_);
  auto op = std::make_shared<SkipOp>(skip_count_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
@@ -40,7 +40,7 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s

// Function to build the TakeOp
Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
  auto op = std::make_shared<TakeOp>(take_count_, connector_que_size_);
  auto op = std::make_shared<TakeOp>(take_count_);
  op->set_total_repeats(GetTotalRepeats());
  op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
  node_ops->push_back(op);
@@ -44,7 +44,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
  ASSERT_TRUE(rc.IsOk());

  // SkipOp
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
  rc = my_tree->AssociateNode(skip_op);
  ASSERT_TRUE(rc.IsOk());
@@ -85,7 +85,7 @@ def test_profiling_complex_pipeline():
        data = json.load(f)
    op_info = data["op_info"]
    assert len(op_info) == 5
    for i in range(5):
    for i in range(4):
        assert "size" in op_info[i]["metrics"]["output_queue"]
        assert "length" in op_info[i]["metrics"]["output_queue"]
        assert "throughput" in op_info[i]["metrics"]["output_queue"]