forked from mindspore-Ecosystem/mindspore

Modify concat to be an inlined Op
Modify zip to be an inlined Op
Modify take to be an inlined Op
Modify skip to be an inlined Op
Modify rename to be an inlined Op
Modify filter to consume child's rows using one thread

parent 9416502e90
commit e4c0bd51a5
@@ -37,22 +37,26 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
   if (builder_sampler_ == nullptr) {
     builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
   }
-  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
-                                    children_start_end_index_);
+  *ptr = std::make_shared<ConcatOp>(builder_sampler_, children_flag_and_nums_, children_start_end_index_);
   return Status::OK();
 }
 
 // Constructor of the ConcatOp.
-ConcatOp::ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
+ConcatOp::ConcatOp(const std::shared_ptr<SamplerRT> &sampler,
                    const std::vector<std::pair<int, int>> &children_flag_and_nums,
                    const std::vector<std::pair<int, int>> &children_start_end_index)
-    : PipelineOp(op_connector_size),
-      children_num_(0),
-      sampler_(sampler),
-      children_flag_and_nums_(children_flag_and_nums),
-      children_start_end_index_(children_start_end_index) {}
+    : ConcatOp() {
+  children_flag_and_nums_ = children_flag_and_nums;
+  children_start_end_index_ = children_start_end_index;
+  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler);
+  if (distribute_sampler != nullptr) {
+    num_shard_ = distribute_sampler->GetDeviceNum();
+    shard_index_ = distribute_sampler->GetDeviceID();
+  }
+}
 
-ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}
+ConcatOp::ConcatOp()
+    : PipelineOp(0), cur_child_(0), verified_(false), num_shard_(1), shard_index_(0), sample_number_(0) {}
 
 // A function that prints info about the Operator
 void ConcatOp::Print(std::ostream &out, bool show_all) const {
@@ -65,98 +69,16 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
     // Call the super class for displaying any common detailed info
     PipelineOp::Print(out, show_all);
     // Then show any custom derived-internal stuff
-    out << "\nDatasets: " << children_num_ << "\n\n";
+    out << "\nDatasets: " << child_.size() << "\n\n";
   }
 }
 
 // This definition is added to pass the cyclomatic complexity rule of <= 20 units
 // The NOLINT directive is to disable cpplint check.
 // Clang format and cpplint give conflicting recommendations on this line below.
-#define f(fv, sv, shard_index) \
-  (((fv) == -1 && (sv) == -1) || ((fv) < (sv) && (shard_index) >= (fv) && (shard_index) < (sv)) || \
-   ((fv) > (sv) && ((shard_index) >= (fv) || (shard_index) < (sv))))  // NOLINT
-
-// Main entry point for Concat
-Status ConcatOp::operator()() {
-  TaskManager::FindMe()->Post();
-  children_num_ = static_cast<int32_t>(child_.size());
-  for (int32_t i = 0; i < children_num_; i++) {
-    children_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
-  }
-  TensorRow new_row;
-  int eof_count = 0;
-  int sample_number = 0;
-  bool is_not_mappable = true;
-  bool is_not_mappable_or_second_ne_zero = true;
-  int num_shard = 1;
-  int shard_index = 0;
-  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
-  if (distribute_sampler != nullptr) {
-    num_shard = distribute_sampler->GetDeviceNum();
-    shard_index = distribute_sampler->GetDeviceID();
-  }
-
-  while (eof_count == 0) {
-    for (int i = 0; i < children_num_; i++) {
-      // 1. Read the first row
-      RETURN_IF_NOT_OK(children_iterators_[i]->FetchNextTensorRow(&new_row));
-      if (new_row.eof()) {
-        eof_count++;
-        continue;
-      }
-      // 2. Do verification as for column name, column data type and rank of column data
-      if (!new_row.eoe()) {
-        RETURN_IF_NOT_OK(Verify(i, new_row));
-      }
-      // 3. Put the data into output_connector
-      if (!children_flag_and_nums_.empty()) {
-        is_not_mappable = children_flag_and_nums_[i].first;
-        is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[i].second);
-      }
-      while (!new_row.eoe() && !new_row.eof()) {
-        // if dataset is not mappable, or is a generator dataset whose source is a yield (cannot get the number of
-        // samples in the python layer), we use filtering to get data
-        if (sample_number % num_shard == shard_index && is_not_mappable_or_second_ne_zero) {
-          RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
-        } else if (!is_not_mappable_or_second_ne_zero) {
-          // if dataset is mappable or a generator dataset whose source is not a yield,
-          // get the start and end subscripts of valid values
-          int fv = children_start_end_index_[i].first, sv = children_start_end_index_[i].second;
-
-          // determine whether the data allocated to the current shard id is false data
-          if (f(fv, sv, shard_index)) {
-            RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
-          }
-        }
-
-        // if dataset is not mappable or generator dataset whose source is a yield, sample_number += 1
-        if (is_not_mappable_or_second_ne_zero) {
-          sample_number++;
-        }
-
-        RETURN_IF_NOT_OK(children_iterators_[i]->FetchNextTensorRow(&new_row));
-      }
-
-      // if dataset is mappable, we don't use filtering to pick data,
-      // so sample_number is increased by the length of the entire dataset
-      if (!is_not_mappable_or_second_ne_zero) {
-        sample_number += children_flag_and_nums_[i].second;
-      }
-    }
-
-    // 4. Add eoe row after getting rows from all children
-    if (eof_count == 0) {
-      RETURN_IF_NOT_OK(out_connector_->SendEOE());
-    }
-    UpdateRepeatAndEpochCounter();
-  }
-  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
-                               "Something went wrong, eof count does not match the number of children.");
-  // 5. Add eof row in the end manually
-  MS_LOG(DEBUG) << "Add the eof row manually in the end.";
-  RETURN_IF_NOT_OK(out_connector_->SendEOF());
-  return Status::OK();
-}
+#define f(fv, sv, shard_index) \
+  ((fv == -1 && sv == -1) || (fv < sv && shard_index >= fv && shard_index < sv) || \
+   (fv > sv && (shard_index >= fv || shard_index < sv)))  // NOLINT
 
 Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
   if (id == 0) {
@@ -174,6 +96,7 @@ Status ConcatOp::Verify(int32_t id, const TensorRow &new_row) {
       }
     }
   }
+  verified_ = true;
   return Status::OK();
 }
 
@@ -211,6 +134,101 @@ Status ConcatOp::GetNumClasses(int64_t *num_classes) {
   *num_classes = max_num_classes;
   return Status::OK();
 }
+
+Status ConcatOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ConcatOp is an inlined operator."); }
+
+bool ConcatOp::IgnoreSample() {
+  bool is_not_mappable_or_second_ne_zero = true;
+
+  if (!children_flag_and_nums_.empty()) {
+    bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
+    is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
+  }
+  bool ret = true;
+  if (sample_number_ % num_shard_ == shard_index_ && is_not_mappable_or_second_ne_zero) {
+    ret = false;
+  } else if (!is_not_mappable_or_second_ne_zero) {
+    // if dataset is mappable or a generator dataset whose source is not a yield,
+    // get the start and end subscripts of valid values
+    int fv = children_start_end_index_[cur_child_].first, sv = children_start_end_index_[cur_child_].second;
+
+    // determine whether the data allocated to the current shard id is false data
+    if (f(fv, sv, shard_index_)) {
+      ret = false;
+    }
+  }
+
+  if (is_not_mappable_or_second_ne_zero) {
+    sample_number_++;
+  }
+  return ret;
+}
+
+Status ConcatOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
+  bool is_not_mappable_or_second_ne_zero = true;
+
+  if (!children_flag_and_nums_.empty()) {
+    bool is_not_mappable = children_flag_and_nums_[cur_child_].first;
+    is_not_mappable_or_second_ne_zero = is_not_mappable || (!children_flag_and_nums_[cur_child_].second);
+  }
+  RETURN_IF_NOT_OK(child_[cur_child_]->GetNextRow(row, worker_id, retry_if_eoe));
+
+  if (!row->eoe() && !row->eof()) {
+    if (!verified_) RETURN_IF_NOT_OK(Verify(cur_child_, *row));
+
+    if (IgnoreSample()) {
+      RETURN_IF_NOT_OK(GetNextRow(row, worker_id, retry_if_eoe));
+    }
+
+    return Status::OK();
+  }
+  if (row->eoe()) {
+    // if last child, send out eoe and reset epoch
+    if (cur_child_ == child_.size() - 1) {
+      // reset
+      cur_child_ = 0;
+      verified_ = false;
+      UpdateRepeatAndEpochCounter();
+      return Status::OK();
+    }
+    if (!is_not_mappable_or_second_ne_zero) {
+      sample_number_ += children_flag_and_nums_[cur_child_].second;
+    }
+    cur_child_++;
+    verified_ = false;
+    RETURN_IF_NOT_OK(GetNextRow(row, worker_id, retry_if_eoe));
+    return Status::OK();
+  }
+  if (row->eof()) {
+    CHECK_FAIL_RETURN_UNEXPECTED(cur_child_ == 0, "Received an unexpected EOF.");
+    for (int32_t i = cur_child_ + 1; i < child_.size(); i++) {
+      RETURN_IF_NOT_OK(child_[i]->GetNextRow(row, worker_id, retry_if_eoe));
+      CHECK_FAIL_RETURN_UNEXPECTED(row->eof(), "Row must be an EOF.");
+    }
+    return Status::OK();
+  }
+
+  return Status::OK();
+}
+
+int32_t ConcatOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Concat operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Concat operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
+int32_t ConcatOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Concat operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
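Aside (not part of the commit): the `f(fv, sv, shard_index)` macro above tests whether a shard index falls in a possibly wrapping half-open interval [fv, sv), with (-1, -1) meaning no restriction. A standalone sketch of the same predicate, where the function name and test values are invented for illustration:

#include <cassert>

// Equivalent of the f(fv, sv, shard_index) macro: true when shard_index falls
// inside the half-open interval [fv, sv); (-1, -1) means "no restriction" and
// fv > sv denotes an interval that wraps around past the last shard.
bool ShardInRange(int fv, int sv, int shard_index) {
  if (fv == -1 && sv == -1) return true;                        // no restriction
  if (fv < sv) return shard_index >= fv && shard_index < sv;    // normal interval
  return shard_index >= fv || shard_index < sv;                 // wrap-around
}

int main() {
  assert(ShardInRange(-1, -1, 3));  // unrestricted: every shard matches
  assert(ShardInRange(1, 3, 2));    // 2 lies inside [1, 3)
  assert(!ShardInRange(1, 3, 3));   // 3 is outside the half-open interval
  assert(ShardInRange(3, 1, 0));    // wrap-around [3, N) plus [0, 1) contains 0
  return 0;
}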
@@ -70,9 +70,8 @@ class ConcatOp : public PipelineOp {
   // Constructor of the ConcatOp.
   // @note The builder class should be used to call it
-  // @param op_connector_size - connector size
-  explicit ConcatOp(int32_t op_connector_size);
-  ConcatOp(int32_t op_connector_size, const std::shared_ptr<SamplerRT> &sampler,
-           const std::vector<std::pair<int, int>> &children_flag_and_nums,
+  ConcatOp();
+  ConcatOp(const std::shared_ptr<SamplerRT> &sampler, const std::vector<std::pair<int, int>> &children_flag_and_nums,
            const std::vector<std::pair<int, int>> &children_start_end_index);
 
   // Destructor
@@ -111,18 +110,29 @@ class ConcatOp : public PipelineOp {
   /// \return Status - The status code return
   Status GetNumClasses(int64_t *num_classes) override;
 
+  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
+  int32_t num_consumers() const override;
+  int32_t num_producers() const override;
+
+  /// Check if the current sample will be taken or dropped
+  /// \return bool
+  bool IgnoreSample();
+
  private:
   Status Verify(int32_t id, const TensorRow &tensor_row);
 
-  int32_t children_num_;  // The num of child of parent node.
   std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between col index and col name
   std::vector<DataType> data_type_;
   std::vector<dsize_t> data_rank_;
-  std::shared_ptr<SamplerRT> sampler_;
   std::vector<std::pair<int, int>> children_flag_and_nums_;
   std::vector<std::pair<int, int>> children_start_end_index_;
 
-  std::vector<std::unique_ptr<ChildIterator>> children_iterators_;  // Iterator for fetching.
+  int32_t cur_child_;
+  bool verified_;
+  int64_t sample_number_;
+
+  int32_t num_shard_;
+  int32_t shard_index_;
 };
 }  // namespace dataset
 }  // namespace mindspore
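Aside (not part of the commit): the pattern this change applies to concat, zip, take, skip, and rename is to stop running each op on its own thread pushing into an output connector, and instead let the parent pull rows through a `GetNextRow` override, with `operator()` reduced to an error stub. A minimal sketch of that pull model, using invented stand-in types rather than the real `DatasetOp`/`TensorRow` interfaces:

#include <memory>
#include <vector>

// Hypothetical, stripped-down version of the inlined-op pull model.
struct Row { bool eoe = false; bool eof = false; };

class Op {
 public:
  virtual ~Op() = default;
  // The parent pulls rows instead of this op pushing into a connector.
  virtual void GetNextRow(Row *row) = 0;
  std::vector<std::shared_ptr<Op>> child_;
};

// An inlined op forwards the pull to its child, transforming along the way.
class PassThroughOp : public Op {
 public:
  void GetNextRow(Row *row) override {
    child_[0]->GetNextRow(row);  // no thread, no queue: one call per row
    // ... transform *row here (rename columns, count skips, zip, etc.) ...
  }
};

Because no thread owns the op, queries such as num_consumers/num_producers now delegate up and down the tree, which is what the repeated num_consumers/num_producers overrides in this commit implement.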
@@ -56,21 +56,52 @@ Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
 
 FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                    std::shared_ptr<TensorOp> predicate_func)
-    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}
-
-Status FilterOp::operator()() {
+    : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {
+  worker_queues_.Init(num_workers, op_queue_size);
+}
+Status FilterOp::LaunchThreadsAndInitOp() {
   // The operator class just starts off threads by calling the tree_ function.
   if (tree_ == nullptr) {
     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "Pipeline init failed, Execution tree not set.");
   }
   filter_queues_.Init(num_workers_, oc_queue_size_);
   RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
-  Status rc =
-      tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id());
+  RETURN_IF_NOT_OK(worker_queues_.Register(tree_->AllTasks()));
+
+  RETURN_IF_NOT_OK(
+      tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1), Name(), id()));
+  RETURN_IF_NOT_OK(
+      tree_->AllTasks()->CreateAsyncTask("FilterCollector", std::bind(&FilterOp::Collector, this), nullptr, id()));
+
+  return Status::OK();
+}
+
+Status FilterOp::operator()() {
   // Synchronize with TaskManager.
+  Status rc = LaunchThreadsAndInitOp();
   TaskManager::FindMe()->Post();
   RETURN_IF_NOT_OK(rc);
-  RETURN_IF_NOT_OK(Collector());
+
+  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
+  TensorRow new_row;
+  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+  int64_t cnt = 0;
+  while (child_iterator_->eof_handled() == false) {
+    while (new_row.empty() == false) {
+      RETURN_IF_NOT_OK(worker_queues_[cnt % num_workers_]->EmplaceBack(new_row));
+      cnt++;
+      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+    }
+
+    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOE))));
+    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+  }
+  RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagEOF))));
+  // EOF received, send quit signal to all workers
+  for (int32_t ind = 0; ind < num_workers_; ind++) {
+    RETURN_IF_NOT_OK(worker_queues_[cnt++ % num_workers_]->EmplaceBack(std::move(TensorRow(TensorRow::kFlagQuit))));
+  }
+
+  return Status::OK();
+}
 
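Aside (not part of the commit): FilterOp keeps its worker threads, but a single thread is now the only consumer of the child. It fans rows out to per-worker queues in round-robin order, and flow control travels in-band as flagged rows: an EOE per epoch, then an EOF followed by one Quit row per worker. A toy version of that dispatch loop, using standard containers in place of the real `QueueList`/`TensorRow` types:

#include <deque>
#include <string>
#include <vector>

enum class Flag { kData, kEOE, kEOF, kQuit };
struct Msg { Flag flag; std::string payload; };

// Round-robin fan-out: one producer, N worker queues, control rows in-band.
void Dispatch(const std::vector<std::vector<std::string>> &epochs,
              std::vector<std::deque<Msg>> *queues) {
  size_t num_workers = queues->size(), cnt = 0;
  for (const auto &epoch : epochs) {
    for (const auto &row : epoch) {
      (*queues)[cnt++ % num_workers].push_back({Flag::kData, row});
    }
    (*queues)[cnt++ % num_workers].push_back({Flag::kEOE, ""});  // epoch boundary
  }
  (*queues)[cnt++ % num_workers].push_back({Flag::kEOF, ""});    // end of data
  for (size_t i = 0; i < num_workers; i++) {                     // stop every worker
    (*queues)[cnt++ % num_workers].push_back({Flag::kQuit, ""});
  }
}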
@@ -110,36 +141,30 @@ void FilterOp::Print(std::ostream &out, bool show_all) const {
 }
 
 Status FilterOp::WorkerEntry(int32_t worker_id) {
-  std::unique_ptr<ChildIterator> child_iterator = std::make_unique<ChildIterator>(this, worker_id, 0);
-
   // Handshake with TaskManager that thread creation is successful.
   TaskManager::FindMe()->Post();
-  bool worker_stop = false;
-  while (worker_stop == false) {
-    // Getting a TensorRow to work on.
-    TensorRow in_row;
-    RETURN_IF_NOT_OK(child_iterator->FetchNextTensorRow(&in_row));
+  TensorRow new_row;
+  RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
 
-    if (in_row.eoe()) {
-      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(in_row, filterCtrl::kFilterEoe)));
-      continue;
-    } else if (in_row.eof()) {
-      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(in_row, filterCtrl::kFilterEof)));
-      worker_stop = true;
-      continue;
-    }
-
-    RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));
-
-    bool result;
-    RETURN_IF_NOT_OK(WorkerCompute(in_row, &result));
-
-    if (result)
-      RETURN_IF_NOT_OK(
-          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_row), filterCtrl::kFilterFull)));
-    else
-      RETURN_IF_NOT_OK(
-          filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_row), filterCtrl::kFilterEmpty)));
+  while (!new_row.quit()) {
+    if (new_row.eoe()) {
+      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEoe)));
+    } else if (new_row.eof()) {
+      RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(new_row, filterCtrl::kFilterEof)));
+    } else {
+      RETURN_IF_NOT_OK(ValidateInColumns(in_columns_));
+
+      bool result;
+      RETURN_IF_NOT_OK(WorkerCompute(new_row, &result));
+
+      if (result)
+        RETURN_IF_NOT_OK(
+            filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterFull)));
+      else
+        RETURN_IF_NOT_OK(
+            filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(new_row), filterCtrl::kFilterEmpty)));
+    }
+
+    RETURN_IF_NOT_OK(worker_queues_[worker_id]->PopFront(&new_row));
   }
   return Status::OK();
 }
@@ -160,15 +185,16 @@ Status FilterOp::WorkerCompute(const TensorRow &in_row, bool *out_predicate) {
 
 // if the filtered TensorRow is written directly to out_connector_,
 // the thread fetching data will block in a queue.
-// Collector function will reorder the TensorRow in order.
+// Collector thread will reorder the TensorRow in order until EOF is received
 // for example in two work queues:
 // in filter_queues_:
-// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
-// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
+// queue1: TR(data1 kFilterEmpty) TR(eoe) TR(data4) TR(eof)
+// queue2: TR(data2) TR(data3 kFilterEmpty) TR(eoe)
 // after reorder in out_connector_:
-// queue1: DB(data2) DB(data4) DB(eof)
-// queue2: DB(eoe) DB(eoe)
+// queue1: TR(data2) TR(data4) TR(eof)
+// queue2: TR(eoe) TR(eoe)
 Status FilterOp::Collector() {
+  TaskManager::FindMe()->Post();
   bool collector_stop = false;
   uint64_t task_id_cnt = 0;
   uint64_t out_id_cnt = 0;
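Aside (not part of the commit): because dispatch is round-robin, the Collector can restore global row order by popping the per-worker result queues in the same round-robin sequence and dropping rows tagged kFilterEmpty. A toy model of that merge (invented names; it assumes the queues were filled by a dispatcher like the sketch above, so every pop finds a row):

#include <deque>
#include <string>
#include <vector>

enum class Ctrl { kFull, kEmpty, kEoe, kEof };
struct Tagged { Ctrl ctrl; std::string row; };

// Pop worker results in the same round-robin order they were dispatched,
// skipping rows the predicate rejected, so global row order is preserved.
std::vector<std::string> Collect(std::vector<std::deque<Tagged>> *queues) {
  std::vector<std::string> out;
  size_t num_workers = queues->size();
  for (size_t idx = 0;; idx++) {
    auto &q = (*queues)[idx % num_workers];
    Tagged t = q.front();
    q.pop_front();
    if (t.ctrl == Ctrl::kEof) break;  // end of data
    if (t.ctrl == Ctrl::kFull) out.push_back(t.row);
    // kEmpty rows are dropped; EOE markers would be forwarded upstream.
  }
  return out;
}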
@@ -216,6 +242,7 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
 
   return Status(StatusCode::kSuccess, "FilterOp predicate func call succeed");
 }
+int32_t FilterOp::num_consumers() const { return 1; }
 
 }  // namespace dataset
 }  // namespace mindspore
@@ -126,6 +126,8 @@ class FilterOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return kFilterOp; }
 
+  int32_t num_consumers() const override;
+
  private:
   // predicate_func python callable which returns a boolean value.
   std::shared_ptr<TensorOp> predicate_func_;
@@ -136,6 +138,10 @@ class FilterOp : public ParallelOp {
   // Internal queue for filter.
   QueueList<std::pair<TensorRow, filterCtrl>> filter_queues_;
 
+  QueueList<TensorRow> worker_queues_;  // internal queue for syncing worker
+
+  std::unique_ptr<ChildIterator> child_iterator_;
+
   // Private function for worker/thread to loop continuously. It comprises the main
   // logic of FilterOp, getting the data from previous Op, validating user specified column names,
   // applying predicate to each of the data, filter the data when predicate result is false.
@@ -168,6 +174,10 @@ class FilterOp : public ParallelOp {
   // @param input_columns The vector of input column names used in the current thread.
   // @return Status The status code returned
   Status ValidateInColumns(const std::vector<std::string> &input_columns);
+
+  // Do the initialization of all queues then start all worker threads
+  // @return Status The status code returned
+  Status LaunchThreadsAndInitOp();
 };
 
 }  // namespace dataset
@@ -43,42 +43,28 @@ Status RenameOp::Builder::SanityCheck() const { return Status::OK(); }
 // build method for RenameOp
 Status RenameOp::Builder::Build(std::shared_ptr<RenameOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_, builder_op_connector_size_);
+  *ptr = std::make_shared<RenameOp>(builder_in_columns_, builder_out_columns_);
   return Status::OK();
 }
 
 // constructor
-RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names,
-                   int32_t op_connector_size)
-    : PipelineOp(op_connector_size), in_columns_(in_col_names), out_columns_(out_col_names) {}
+RenameOp::RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names)
+    : PipelineOp(0), in_columns_(in_col_names), out_columns_(out_col_names) {}
 
 // destructor
 RenameOp::~RenameOp() {}
 
-// main entry point for rename
-Status RenameOp::operator()() {
-  TaskManager::FindMe()->Post();
-  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
-
-  TensorRow new_row;
-  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-
-  while (!new_row.eof()) {
-    while (!new_row.eoe()) {
-      MS_LOG(DEBUG) << "Rename operator pushing next row.";
-      RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
-      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-    }
-    RETURN_IF_NOT_OK(out_connector_->SendEOE());
-    MS_LOG(DEBUG) << "Rename operator EOE Received.";
-    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-    MS_LOG(DEBUG) << "Rename operator fetching row after EOE.";
+// Gets a row from the child operator and projects the row.
+Status RenameOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
+  RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
+  if (row->eoe()) {
+    UpdateRepeatAndEpochCounter();
   }
-  RETURN_IF_NOT_OK(out_connector_->SendEOF());
-  MS_LOG(DEBUG) << "Rename operator EOF Received.";
   return Status::OK();
 }
 
+Status RenameOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. RenameOp is an inlined operator."); }
+
 // Rename core functionality to compute the new column name id map.
 // We need to overwrite the super class ComputeColMap here because we're making a modification of the
 // map from the child map.
@@ -151,15 +137,25 @@ void RenameOp::Print(std::ostream &out,  // In: The output stream to print t
   }
 }
 
-Status RenameOp::EofReceived(int32_t) {
-  MS_LOG(DEBUG) << "Rename operator EOF received, do nothing now.";
-  return Status::OK();
+int32_t RenameOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Rename operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Rename operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
 }
 
-Status RenameOp::EoeReceived(int32_t) {
-  state_ = OpState::kDeOpIdle;
-  return Status::OK();
+int32_t RenameOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Rename operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
 }
 
 }  // namespace dataset
 }  // namespace mindspore
@@ -80,17 +80,11 @@ class RenameOp : public PipelineOp {
   // @param in_col_names names of columns to rename
   // @param out_col_names names of columns after rename
-  // @param op_connector_size connector size
-  RenameOp(const std::vector<std::string> &in_col_names,   // In: Col names to consume
-           const std::vector<std::string> &out_col_names,  // In: Col names to produce
-           int32_t op_connector_size);
+  RenameOp(const std::vector<std::string> &in_col_names, const std::vector<std::string> &out_col_names);
 
   // Destructor
   ~RenameOp();
 
-  Status EofReceived(int32_t) override;
-
-  Status EoeReceived(int32_t) override;
-
   // Print function for Rename
   // @param out output stream to print to
   // @param show_all if it should print everything
@@ -112,6 +106,13 @@ class RenameOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return kRenameOp; }
 
+  // Gets a row from the child node and projects that row. The caller is typically our parent node.
+  // @param row - output pointer to the projected row.
+  // @param worker_id - The worker id
+  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
+  int32_t num_consumers() const override;
+  int32_t num_producers() const override;
+
  protected:
   // Rename core functionality
   // Computing the assignment of the new column name map.
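Aside (not part of the commit): RenameOp does no per-row work, which is why its `GetNextRow` can be a pure pass-through; the rename itself lives in the column-name-to-id map that `ComputeColMap` rewrites. A toy version of that map rewrite, with invented names (the real code reports an error on a missing input column rather than skipping it):

#include <string>
#include <unordered_map>
#include <vector>

// Rewrite selected keys of a column-name -> column-id map, as RenameOp's
// ComputeColMap does; the rows themselves are never touched.
std::unordered_map<std::string, int> RenameColumns(
    std::unordered_map<std::string, int> col_map,
    const std::vector<std::string> &in_cols, const std::vector<std::string> &out_cols) {
  for (size_t i = 0; i < in_cols.size(); i++) {
    auto it = col_map.find(in_cols[i]);
    if (it == col_map.end()) continue;  // toy shortcut: real code errors out here
    int id = it->second;
    col_map.erase(it);
    col_map[out_cols[i]] = id;  // same column id, new name
  }
  return col_map;
}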
@@ -43,13 +43,12 @@ Status SkipOp::Builder::SanityCheck() const {
 // The builder "build" method creates the final object.
 Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<SkipOp>(build_max_skips_, builder_op_connector_size_);
+  *ptr = std::make_shared<SkipOp>(build_max_skips_);
   return Status::OK();
 }
 
 // Constructor of the SkipOp.
-SkipOp::SkipOp(int32_t count, int32_t op_connector_size)
-    : PipelineOp(op_connector_size), max_skips_(count), skip_count_(0) {}
+SkipOp::SkipOp(int32_t count) : PipelineOp(0), max_skips_(count), skip_count_(0) {}
 
 // Destructor
 SkipOp::~SkipOp() {}
@@ -69,34 +68,48 @@ void SkipOp::Print(std::ostream &out, bool show_all) const {
   }
 }
 
-// main entry point for skip
-Status SkipOp::operator()() {
-  TaskManager::FindMe()->Post();
-  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
+Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is an inlined operator."); }
 
-  TensorRow new_row;
-  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-  while (!new_row.eof()) {
-    // Reset count
-    skip_count_ = 0;
-    while (!new_row.eoe()) {
-      // Drop first count rows
-      if (skip_count_ < max_skips_) {
-        skip_count_++;
-      } else {
-        RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
-      }
-      RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+Status SkipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
+  bool eoe_received = false;
+  while (skip_count_ < max_skips_) {
+    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
+    if (row->eoe()) {
+      eoe_received = true;
+      break;
     }
-    // we got eoe, now try again until we get eof
-    MS_LOG(DEBUG) << "Skip operator EOE Received.";
-    RETURN_IF_NOT_OK(out_connector_->SendEOE());
-    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+    skip_count_++;
   }
-  MS_LOG(DEBUG) << "Skip operator EOF Received.";
-  RETURN_IF_NOT_OK(out_connector_->SendEOF());
+  if (!eoe_received) {
+    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
+  }
+  if (row->eoe()) {
+    UpdateRepeatAndEpochCounter();
+    skip_count_ = 0;
+  }
   return Status::OK();
 }
 
+int32_t SkipOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Skip operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Skip operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
+int32_t SkipOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Skip operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
+
 }  // namespace dataset
 }  // namespace mindspore
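Aside (not part of the commit): SkipOp's `GetNextRow` now discards up to `max_skips_` rows on demand and resets `skip_count_` whenever an EOE passes through, so every epoch skips afresh. The same behaviour over plain integers, with `std::nullopt` standing in for EOE (all names invented):

#include <functional>
#include <optional>

// Toy skip: pull from `next`, dropping the first `max_skips` values per epoch.
class Skip {
 public:
  explicit Skip(int max_skips) : max_skips_(max_skips) {}
  std::optional<int> GetNext(const std::function<std::optional<int>()> &next) {
    while (skip_count_ < max_skips_) {
      auto v = next();
      if (!v) { skip_count_ = 0; return v; }  // EOE before quota: reset, forward it
      skip_count_++;                          // silently drop this value
    }
    auto v = next();
    if (!v) skip_count_ = 0;  // epoch boundary: next epoch skips again
    return v;
  }
 private:
  int max_skips_, skip_count_ = 0;
};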
@@ -50,7 +50,7 @@ class SkipOp : public PipelineOp {
   // Constructor of the SkipOp.
   // @note The builder class should be used to call it
   // @param count - The number of skips to do
-  explicit SkipOp(int32_t count, int32_t op_connector_size);
+  explicit SkipOp(int32_t count);
 
   // Destructor
   ~SkipOp();
@@ -69,6 +69,9 @@ class SkipOp : public PipelineOp {
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return kSkipOp; }
+  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
+  int32_t num_consumers() const override;
+  int32_t num_producers() const override;
 
  private:
   int32_t max_skips_;  // The number of skips that the user requested
@@ -43,13 +43,12 @@ Status TakeOp::Builder::SanityCheck() const {
 // The builder "build" method creates the final object.
 Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<TakeOp>(build_max_takes_, builder_op_connector_size_);
+  *ptr = std::make_shared<TakeOp>(build_max_takes_);
   return Status::OK();
 }
 
 // Constructor of the TakeOp.
-TakeOp::TakeOp(int32_t count, int32_t op_connector_size)
-    : PipelineOp(op_connector_size), max_takes_(count), take_count_(0) {}
+TakeOp::TakeOp(int32_t count) : PipelineOp(0), max_takes_(count), take_count_(0) {}
 
 // A print method typically used for debugging
 void TakeOp::Print(std::ostream &out, bool show_all) const {
@@ -66,37 +65,53 @@ void TakeOp::Print(std::ostream &out, bool show_all) const {
   }
 }
 
-// Main entry point for Take
-Status TakeOp::operator()() {
-  TaskManager::FindMe()->Post();
-  child_iterator_ = std::make_unique<ChildIterator>(this, 0, 0);
+Status TakeOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. TakeOp is an inlined operator."); }
 
-  TensorRow new_row;
-  RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-
-  while (!new_row.eof()) {
-    while (!new_row.eoe()) {
-      if (take_count_ < max_takes_) {
-        RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row)));
-        take_count_++;
-        RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
-      }
-      if (take_count_ == max_takes_) {
-        RETURN_IF_NOT_OK(child_iterator_->Drain());
-        break;
-      }
+Status TakeOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
+  bool eoe_received = false;
+  if (take_count_ < max_takes_) {
+    RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
+    if (row->eoe()) {
+      eoe_received = true;
+    } else {
+      take_count_++;
+      return Status::OK();
+    }
   }
+  if (take_count_ == max_takes_) {
+    // drain
+    while (!row->eoe()) {
+      RETURN_IF_NOT_OK(child_[0]->GetNextRow(row, worker_id, retry_if_eoe));
+    }
+    eoe_received = true;
+  }
+  if (eoe_received) {
     UpdateRepeatAndEpochCounter();
-    RETURN_IF_NOT_OK(out_connector_->SendEOE());
-    RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&new_row));
+    take_count_ = 0;
   }
-
-  take_count_ = 0;
-  MS_LOG(DEBUG) << "Meet the end and push-back eof row.";
-  RETURN_IF_NOT_OK(out_connector_->SendEOF());
   return Status::OK();
 }
 
+int32_t TakeOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Take operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Take operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
+int32_t TakeOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Take operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
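Aside (not part of the commit): TakeOp is the mirror image of skip. It forwards rows until the quota is reached, then must keep pulling and discarding until the child's EOE, so the child is left aligned at the epoch boundary rather than mid-epoch. A toy model under the same conventions as the skip sketch above (invented names, `std::nullopt` as EOE):

#include <functional>
#include <optional>

// Toy take: forward the first `max_takes` values, then keep pulling and
// discarding until the epoch marker so the child stays in sync.
class Take {
 public:
  explicit Take(int max_takes) : max_takes_(max_takes) {}
  std::optional<int> GetNext(const std::function<std::optional<int>()> &next) {
    if (take_count_ < max_takes_) {
      auto v = next();
      if (v) { take_count_++; return v; }
      take_count_ = 0;  // EOE arrived before the quota was reached
      return v;
    }
    std::optional<int> v;
    do { v = next(); } while (v);  // drain the rest of the epoch
    take_count_ = 0;
    return v;                      // forward the EOE
  }
 private:
  int max_takes_, take_count_ = 0;
};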
@@ -53,7 +53,7 @@ class TakeOp : public PipelineOp {
   // Constructor of the TakeOp.
   // @note The builder class should be used to call it
   // @param count - The number of takes to do
-  explicit TakeOp(int32_t count, int32_t op_connector_size);
+  explicit TakeOp(int32_t count);
 
   // Destructor
   ~TakeOp() = default;
@@ -82,6 +82,10 @@ class TakeOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return kTakeOp; }
 
+  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
+  int32_t num_consumers() const override;
+  int32_t num_producers() const override;
+
  private:
   int32_t max_takes_;   // The number of takes that the user requested
   int32_t take_count_;  // A counter for the current number of executed takes
@@ -45,106 +45,40 @@ Status ZipOp::Builder::Build(std::shared_ptr<ZipOp> *ptr) {
 }
 
 // Construct ZipOp here, local variables initialized in operator due to tree construction restrictions
-ZipOp::ZipOp(int32_t op_connector_size)
-    : PipelineOp(op_connector_size), children_num_(0), draining_(false), eof_(false) {}
+ZipOp::ZipOp(int32_t op_connector_size) : PipelineOp(0) {}
 
 // destructor
 ZipOp::~ZipOp() {}
 
-// Entry point for Zip, called by launch()
-Status ZipOp::operator()() {
-  // The children_num_ parameter needs to be put here
-  children_num_ = child_.size();
-  // Synchronize with TaskManager once the thread is created.
-  TaskManager::FindMe()->Post();
-
-  // initialize the iterators
-  for (int32_t i = 0; i < children_num_; ++i) {
-    // magic number 0 since Zip is not a parallel Op
-    child_iterators_.push_back(std::make_unique<ChildIterator>(this, 0, i));
-  }
-
-  // Loop until eof is true
-  while (!eof_) {
-    // 1 Prepare new epoch
-    RETURN_IF_NOT_OK(prepare());
-    // 2 fetch first row
-    TensorRow row;
-    RETURN_IF_NOT_OK(getNextTensorRow(&row));
-
-    // If an eof got picked up, then we're done
-    if (eof_) {
-      break;
-    }
-    while (!draining_) {
-      // 3 send new row to the out connector
-      MS_LOG(DEBUG) << "Zip operator finished one row, pushing, cols " << row.size() << ", map "
-                    << column_name_id_map_.size() << ".";
-      RETURN_IF_NOT_OK(out_connector_->Add(std::move(row)));
-      // 4 fetch one more row
-      RETURN_IF_NOT_OK(getNextTensorRow(&row));
-    }
-    // 5 handle drain state.
-    if (draining_) {
-      MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
-      RETURN_IF_NOT_OK(drainPipeline());
-      // Now that we have drained child inputs, send the eoe up.
-      RETURN_IF_NOT_OK(out_connector_->SendEOE());
-    }
-  }
-
-  // 6 handle eof
-  MS_LOG(DEBUG) << "Zip operator got EOF, propagating.";
-  RETURN_IF_NOT_OK(out_connector_->SendEOF());
-  return Status::OK();
-}
-
-// Handles preprocessing of the main loop, used when starting new epoch
-Status ZipOp::prepare() {
-  MS_LOG(DEBUG) << "Zip operator prepares for new epoch.";
-  draining_ = false;
-  return Status::OK();
-}
-
 // fetches next zipped (merged) row
-Status ZipOp::getNextTensorRow(TensorRow *const new_zip_row) {
+Status ZipOp::getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, int32_t worker_id,
+                               bool retry_if_eoe) {
+  *new_zip_row = {};
   // iterate over all iterators and generate a row
-  for (int32_t i = 0; i < children_num_; ++i) {
-    TensorRow new_row = {};
-    RETURN_IF_NOT_OK((child_iterators_[i])->FetchNextTensorRow(&new_row));
-    // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
-    if (new_row.empty()) {
-      // If we did not get a row from any of the children, then it's the end of an epoch and we can move
-      // to drain state.
-      MS_LOG(DEBUG) << "Zip operator child iterator produced empty row.";
-      draining_ = true;
-      new_zip_row->clear();
-      // If we picked up an eof here, then we are completely done.
-      if ((child_iterators_[i])->eof_handled()) {
-        MS_LOG(DEBUG) << "Zip operator iterator got EOF.";
-        eof_ = true;
-      }
+  for (int32_t i = 0; i < child_.size(); ++i) {
+    TensorRow new_row;
+    RETURN_IF_NOT_OK(child_[i]->GetNextRow(&new_row, worker_id, retry_if_eoe));
+    if (new_row.eoe() || new_row.eof()) {
+      *new_zip_row = new_row;
+      *skip_child = i;
       return Status::OK();
     } else {
       MS_LOG(DEBUG) << "Zip operator got row from child " << i << ". Num cols: " << new_row.size() << ".";
       // if row isn't empty then we can append the fetched row with new_zip_row
       new_zip_row->insert(new_zip_row->end(), new_row.begin(), new_row.end());
     }
   }
   MS_LOG(DEBUG) << "Zip operator builds a zipped row. Number of columns in row: " << new_zip_row->size() << ".";
   return Status::OK();
 }
 
 // drain end of epoch messages from iterator for this epoch
-Status ZipOp::drainPipeline() {
-  // we don't need to drain if we reached eof
-  if (eof_) {
-    return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
-                  "ZipOp draining should not be done if already at eof!");
-  }
-  for (int32_t con = 0; con < children_num_; ++con) {
+Status ZipOp::drainPipeline(int32_t skip_child, int32_t worker_id, bool retry_if_eoe) {
+  for (int32_t con = 0; con < child_.size(); ++con) {
+    if (con == skip_child) continue;
     MS_LOG(DEBUG) << "Zip operator draining child at " << con << ".";
-    RETURN_IF_NOT_OK(child_iterators_[con]->Drain());
+    TensorRow row;
+    while (!row.eoe()) {
+      RETURN_IF_NOT_OK(child_[con]->GetNextRow(&row, worker_id, retry_if_eoe));
+    }
   }
   // at this point all connectors don't contain end of epoch messages. next iteration should be clean
   return Status::OK();
@@ -161,9 +95,9 @@ void ZipOp::Print(std::ostream &out,  // In: The output stream to print to
   } else {
     // Call the super class for displaying any common detailed info
     PipelineOp::Print(out, show_all);
-    // Then show any custom derived-internal stuff
-    out << "\nDatasets: " << children_num_ << "\n\n";
   }
+  // Then show any custom derived-internal stuff
+  out << "\nDatasets: " << child_.size() << "\n\n";
 }
 
 // overwrite function and handle eof
@@ -202,5 +136,39 @@ Status ZipOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+Status ZipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. ZipOp is an inlined operator."); }
+
+Status ZipOp::GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) {
+  int32_t skip_child = -1;
+  RETURN_IF_NOT_OK(getNextZippedRow(row, &skip_child, worker_id, retry_if_eoe));
+  if (row->eoe()) {
+    UpdateRepeatAndEpochCounter();
+    MS_LOG(DEBUG) << "Zip operator is now draining child inputs.";
+    RETURN_IF_NOT_OK(drainPipeline(skip_child, worker_id, retry_if_eoe));
+  }
+  return Status::OK();
+}
+
+int32_t ZipOp::num_consumers() const {
+  if (parent_.empty()) {
+    MS_LOG(DEBUG) << "Zip operator, no parent node, assuming it's the root and returning 1.";
+    return 1;
+  } else if (parent_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Zip operator, pointer to the first parent is null. Returning 0.";
+    return 0;
+  } else {
+    return parent_[0]->num_consumers();
+  }
+}
+
+int32_t ZipOp::num_producers() const {
+  if (child_.empty() || child_[0] == nullptr) {
+    MS_LOG(DEBUG) << "Zip operator, pointer to child node is null. Returning 0.";
+    return 0;
+  } else {
+    return child_[0]->num_producers();
+  }
+}
 }  // namespace dataset
 }  // namespace mindspore
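Aside (not part of the commit): `getNextZippedRow` concatenates one row from every child into a single wide row; when any child yields EOE or EOF it reports which child fired via `skip_child`, so `drainPipeline` can pull every other child through to its own epoch boundary. A toy version with invented types, where `std::nullopt` stands in for EOE:

#include <functional>
#include <optional>
#include <vector>

using Source = std::function<std::optional<std::vector<int>>()>;  // nullopt == EOE

// Toy zip: concatenate one row from each child; when a child hits EOE,
// report which one so the caller can drain the remaining children.
std::optional<std::vector<int>> GetNextZipped(const std::vector<Source> &children,
                                              int *eoe_child) {
  std::vector<int> zipped;
  for (size_t i = 0; i < children.size(); i++) {
    auto row = children[i]();
    if (!row) {             // epoch ended on child i
      *eoe_child = static_cast<int>(i);
      return std::nullopt;  // the others must still be drained to their EOE
    }
    zipped.insert(zipped.end(), row->begin(), row->end());
  }
  return zipped;
}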
@@ -104,15 +104,16 @@ class ZipOp : public PipelineOp {
   // @return Name of the current Op
   std::string Name() const override { return kZipOp; }
 
- private:
-  // Handles preprocessing of the main loop, used when starting new epoch
-  Status prepare();
+  Status GetNextRow(TensorRow *row, int32_t worker_id, bool retry_if_eoe) override;
+  int32_t num_consumers() const override;
+  int32_t num_producers() const override;
+
+ private:
   // Special handle case where an empty row has been received from child iterator
   // @note - we need to drain eoe signals from all children connectors.
   // @details - when this function is called, then we encountered eoe at child iterator
   // we have to drain rows from other child iterators until we hit eoe from all other child iterators
-  Status drainPipeline();
+  Status drainPipeline(int32_t skip_child, int32_t worker_id, bool retry_if_eoe);
 
   // Merges 1 row from each childIterator together
   // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
@@ -125,16 +126,11 @@ class ZipOp : public PipelineOp {
   //    1  a  T
   //     \ | /
   //    1, a, T
-  Status getNextTensorRow(TensorRow *const new_zip_row);
+  Status getNextZippedRow(TensorRow *const new_zip_row, int32_t *skip_child, int32_t worker_id, bool retry_if_eoe);
 
   // Computing the assignment of the column name map.
   // @return - Status
   Status ComputeColMap() override;
-
-  int32_t children_num_;
-  bool draining_;
-  bool eof_;
-  std::vector<std::unique_ptr<ChildIterator>> child_iterators_;
 };
 }  // namespace dataset
 }  // namespace mindspore
@@ -123,12 +123,11 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
 Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
   std::shared_ptr<ConcatOp> op;
   if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
-    op = std::make_shared<ConcatOp>(connector_que_size_);
+    op = std::make_shared<ConcatOp>();
   } else {
     std::shared_ptr<SamplerRT> sampler_rt = nullptr;
     RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
-    op =
-      std::make_shared<ConcatOp>(connector_que_size_, sampler_rt, children_flag_and_nums_, children_start_end_index_);
+    op = std::make_shared<ConcatOp>(sampler_rt, children_flag_and_nums_, children_start_end_index_);
   }
   op->set_total_repeats(GetTotalRepeats());
   op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
@@ -58,7 +58,7 @@ Status RenameNode::ValidateParams() {
 }
 
 Status RenameNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
-  auto op = std::make_shared<RenameOp>(input_columns_, output_columns_, connector_que_size_);
+  auto op = std::make_shared<RenameOp>(input_columns_, output_columns_);
   op->set_total_repeats(GetTotalRepeats());
   op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
@@ -39,7 +39,7 @@ void SkipNode::Print(std::ostream &out) const { out << Name() + "(skip_count:" +
 
 // Function to build the SkipOp
 Status SkipNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
-  auto op = std::make_shared<SkipOp>(skip_count_, connector_que_size_);
+  auto op = std::make_shared<SkipOp>(skip_count_);
   op->set_total_repeats(GetTotalRepeats());
   op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
@@ -40,7 +40,7 @@ void TakeNode::Print(std::ostream &out) const { out << Name() + "(num_rows:" + s
 
 // Function to build the TakeOp
 Status TakeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
-  auto op = std::make_shared<TakeOp>(take_count_, connector_que_size_);
+  auto op = std::make_shared<TakeOp>(take_count_);
   op->set_total_repeats(GetTotalRepeats());
   op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
   node_ops->push_back(op);
@@ -44,7 +44,7 @@ TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
   ASSERT_TRUE(rc.IsOk());
 
   // SkipOp
-  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5, 2);
+  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
   rc = my_tree->AssociateNode(skip_op);
   ASSERT_TRUE(rc.IsOk());
 
@@ -85,7 +85,7 @@ def test_profiling_complex_pipeline():
         data = json.load(f)
         op_info = data["op_info"]
         assert len(op_info) == 5
-        for i in range(5):
+        for i in range(4):
             assert "size" in op_info[i]["metrics"]["output_queue"]
             assert "length" in op_info[i]["metrics"]["output_queue"]
             assert "throughput" in op_info[i]["metrics"]["output_queue"]