AutoTune: support non-sink mode

Add time stats to iterator tracing
hesham 2021-11-24 10:55:54 -05:00
parent a0ab39248d
commit cf3ff77cfd
12 changed files with 87 additions and 171 deletions

View File

@@ -136,11 +136,7 @@ Status AutoTune::RunIteration() {
// Run every epoch
if ((profiling_manager_->GetNumOfProfiledEpochs()) >= cur_epoch_) {
MS_LOG(INFO) << "Run AutoTune at epoch #" << cur_epoch_;
if (IsSink()) {
RETURN_IF_NOT_OK(RunIterationSink());
} else {
RETURN_IF_NOT_OK(RunIterationNonSink());
}
RETURN_IF_NOT_OK(RunIterationEpoch());
++cur_epoch_;
}
return Status::OK();
@@ -154,7 +150,7 @@ Status AutoTune::RecordPipelineTime() {
<< " ms. The avg pipeline time for all epochs is " << Mean(avg_pipeline_times_) << "ms";
return Status::OK();
}
Status AutoTune::RunIterationSink() {
Status AutoTune::RunIterationEpoch() {
RETURN_IF_NOT_OK(RecordPipelineTime());
bool isBottleneck = false;
RETURN_IF_NOT_OK(IsDSaBottleneck(&isBottleneck));
@@ -163,9 +159,6 @@ Status AutoTune::RunIterationSink() {
}
return Status::OK();
}
Status AutoTune::RunIterationNonSink() {
return Status(StatusCode::kMDUnexpectedError, "AutoTune doesn't support non-sink pipeline.");
}
Status AutoTune::IsDSaBottleneck(bool *isBottleneck) {
std::vector<int32_t> sizes;
@@ -243,6 +236,8 @@ Status AutoTune::Analyse() {
int64_t queue_capacity = 0;
RETURN_IF_NOT_OK(GetOpConnectorCapacity(op_id, &queue_capacity));
MS_LOG(DEBUG) << "Op (" << ops_[op_id]->NameWithID() << ") CPU=" << cpu_util / num_workers
<< ", in=" << input_queue_util << "out=" << output_queue_util;
// map decisions - queue
if (queue_diff > INPUT_OUTPUT_QUEUE_DIFF_THRESHOLD) {
MS_LOG(WARNING) << "Op (" << ops_[op_id]->NameWithID()

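The refactor above folds the sink-only path into a single per-epoch routine. A minimal standalone sketch of the resulting control flow (simplified stand-in types, not the MindSpore implementation):

#include <iostream>

struct Status { bool ok = true; };  // stand-in for mindspore::dataset::Status

struct AutoTuneSketch {
  int cur_epoch_ = 1;
  int profiled_epochs_ = 0;  // stand-in for profiling_manager_->GetNumOfProfiledEpochs()

  Status RunIterationEpoch() {
    // In the real code: RecordPipelineTime(), IsDSaBottleneck(), Analyse().
    std::cout << "AutoTune epoch #" << cur_epoch_ << "\n";
    return Status{};
  }

  Status RunIteration() {
    // No IsSink() branch any more: sink and non-sink pipelines share this path.
    if (profiled_epochs_ >= cur_epoch_) {
      RunIterationEpoch();
      ++cur_epoch_;
    }
    return Status{};
  }
};

int main() {
  AutoTuneSketch at;
  at.profiled_epochs_ = 1;
  at.RunIteration();  // prints "AutoTune epoch #1"
  return 0;
}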
View File

@@ -57,13 +57,9 @@ class AutoTune {
/// \return status code
Status RunIteration();
/// The AutoTune logic for sink pipelines that executes every iteration
/// The AutoTune logic for pipelines that executes every epoch
/// \return status code
Status RunIterationSink();
/// The AutoTune logic for non-sink pipelines that executes every iteration
/// \return status code
Status RunIterationNonSink();
Status RunIterationEpoch();
/// Check if the dataset pipeline is the bottleneck
/// \param[out] isBottleneck bool

View File

@@ -132,8 +132,8 @@ Status ConnectorSize::GetOpConnectorSize(int32_t op_id, uint64_t start_time, uin
auto start_index = std::distance(ts_.begin(), lower);
auto end_index = std::distance(ts_.begin(), upper);
MS_LOG(INFO) << "start_index: " << start_index << " end_index: " << end_index;
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
CHECK_FAIL_RETURN_UNEXPECTED(start_index <= end_index,
"Expected start_index <= end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
// convert indices to sample_table_ iterator
auto first_iter = sample_table_.begin() + start_index;

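The relaxed check above (start_index <= end_index) matters when the requested time window contains no connector-size samples yet; lower_bound and upper_bound then land on the same position. A minimal standalone sketch with illustrative timestamps, not the ConnectorSize internals:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iterator>
#include <vector>

int main() {
  std::vector<uint64_t> ts = {0, 100, 200, 300};         // sampling timestamps
  uint64_t start_time = 150, end_time = 160;             // window with no sample in it
  auto lower = std::lower_bound(ts.begin(), ts.end(), start_time);
  auto upper = std::upper_bound(ts.begin(), ts.end(), end_time);
  auto start_index = std::distance(ts.begin(), lower);   // 2
  auto end_index = std::distance(ts.begin(), upper);     // 2
  assert(start_index <= end_index);                      // empty range, no longer an error
  return 0;
}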
View File

@@ -129,8 +129,8 @@ Status SystemCpuInfo::SampleAndGetCurrPrevStat(SystemStat *current_stat, SystemS
Status SystemCpuInfo::GetUserCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint8_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< " sys_cpu_util.size: " << sys_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
CHECK_FAIL_RETURN_UNEXPECTED(start_index <= end_index,
"Expected start_index <= end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= sys_cpu_util_.size(),
@@ -144,8 +144,8 @@ Status SystemCpuInfo::GetUserCpuUtil(uint64_t start_index, uint64_t end_index, s
Status SystemCpuInfo::GetSysCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint8_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< "sys_cpu_util.size: " << sys_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
CHECK_FAIL_RETURN_UNEXPECTED(start_index <= end_index,
"Expected start_index <= end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= sys_cpu_util_.size(),
@@ -279,8 +279,8 @@ Status MDOperatorCpuInfo::GetUserCpuUtil(uint64_t start_index, uint64_t end_inde
std::vector<uint16_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< " op_cpu_util_.size: " << op_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
CHECK_FAIL_RETURN_UNEXPECTED(start_index <= end_index,
"Expected start_index <= end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= op_cpu_util_.size(),
@@ -297,8 +297,8 @@ Status MDOperatorCpuInfo::GetUserCpuUtil(uint64_t start_index, uint64_t end_inde
Status MDOperatorCpuInfo::GetSysCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint16_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< " op_cpu_util_.size: " << op_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
CHECK_FAIL_RETURN_UNEXPECTED(start_index <= end_index,
"Expected start_index <= end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= op_cpu_util_.size(),

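The same relaxation in the CPU sampler lets the utilization getters return an empty slice for such a window instead of failing. A trivial sketch with hypothetical values:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint8_t> sys_cpu_util = {12, 40, 35};      // sampled utilization values
  uint64_t start_index = 3, end_index = 3;               // no samples in the requested window
  // end_index <= sys_cpu_util.size() still holds, and the copied slice is simply empty.
  std::vector<uint8_t> result(sys_cpu_util.begin() + start_index,
                              sys_cpu_util.begin() + end_index);
  std::cout << "slice size: " << result.size() << "\n";  // prints 0
  return 0;
}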
View File

@@ -27,58 +27,6 @@
namespace mindspore {
namespace dataset {
constexpr int32_t CONNECTOR_DEPTH_OFFSET = 0;
Status DatasetIteratorTracing::Init() {
(void)ts_.emplace_back(0);
return Status::OK();
}
Status DatasetIteratorTracing::GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return {StatusCode::kMDUnexpectedError, "Dataset Iterator Tracing does not record pipeline time."};
}
Status DatasetIteratorTracing::GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return {StatusCode::kMDUnexpectedError, "Dataset Iterator Tracing does not record push time."};
}
Status DatasetIteratorTracing::GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return {StatusCode::kMDUnexpectedError, "Dataset Iterator Tracing does not record batch time."};
}
Status DatasetIteratorTracing::GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, CONNECTOR_DEPTH_OFFSET, "value", result);
}
Status DatasetIteratorTracing::GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) {
std::lock_guard<std::mutex> guard(lock_);
auto total_steps = records_.size() / records_per_step_;
MS_LOG(DEBUG) << "start_step: " << start_step << " end_step: " << end_step;
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= total_steps,
"Expected start_step <= total_steps. Got start_step: " + std::to_string(start_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(end_step <= total_steps,
"Expected end_step <= total_steps. Got end_step: " + std::to_string(end_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= end_step,
"Expected start_step <= end_step. Got start_step: " + std::to_string(start_step) +
" end_step: " + std::to_string(end_step));
uint32_t total = end_step - start_step + 1;
uint32_t count = 0U;
for (auto step_num = start_step; step_num <= end_step; step_num++) {
auto idx = (step_num - 1) * records_per_step_ + CONNECTOR_DEPTH_OFFSET;
count += static_cast<uint32_t>(records_[idx].value == 0);
}
*empty_queue_freq = static_cast<float_t>(count) / static_cast<float_t>(total);
return Status::OK();
}
Status DatasetIteratorTracing::GetConnectorCapacity(int32_t start_step, int32_t end_step,
std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, CONNECTOR_DEPTH_OFFSET, "extra_info", result);
}
Path DatasetIteratorTracing::GetFileName(const std::string &dir_path, const std::string &rank_id) {
return Path(dir_path) / Path("dataset_iterator_profiling_" + rank_id + ".txt");
}

View File

@@ -23,26 +23,16 @@
namespace mindspore {
namespace dataset {
constexpr int32_t RECORDS_PER_STEP_DATASET_ITERATOR = 1;
class DatasetIteratorTracing : public Tracing {
public:
// Constructor
DatasetIteratorTracing() : Tracing(RECORDS_PER_STEP_DATASET_ITERATOR) {}
DatasetIteratorTracing() = default;
// Destructor
~DatasetIteratorTracing() override = default;
std::string Name() const override { return kDatasetIteratorTracingName; };
Status Init() override;
Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetConnectorCapacity(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) override;
private:
Path GetFileName(const std::string &dir_path, const std::string &rank_id) override;
};

View File

@@ -28,60 +28,6 @@
namespace mindspore {
namespace dataset {
constexpr int32_t PUSH_TIME_OFFSET = 0;
constexpr int32_t BATCH_TIME_OFFSET = 1;
constexpr int32_t PIPELINE_TIME_OFFSET = 2;
constexpr int32_t CONNECTOR_DEPTH_OFFSET = 3;
Status DeviceQueueTracing::Init() {
(void)ts_.emplace_back(0);
return Status::OK();
}
Status DeviceQueueTracing::GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, PIPELINE_TIME_OFFSET, "value", result);
}
Status DeviceQueueTracing::GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, PUSH_TIME_OFFSET, "value", result);
}
Status DeviceQueueTracing::GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, BATCH_TIME_OFFSET, "value", result);
}
Status DeviceQueueTracing::GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, CONNECTOR_DEPTH_OFFSET, "value", result);
}
Status DeviceQueueTracing::GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) {
std::lock_guard<std::mutex> guard(lock_);
auto total_steps = records_.size() / records_per_step_;
MS_LOG(DEBUG) << "start_step: " << start_step << " end_step: " << end_step;
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= total_steps,
"Expected start_step <= total_steps. Got start_step: " + std::to_string(start_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(end_step <= total_steps,
"Expected end_step <= total_steps. Got end_step: " + std::to_string(end_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= end_step,
"Expected start_step <= end_step. Got start_step: " + std::to_string(start_step) +
" end_step: " + std::to_string(end_step));
uint32_t total = end_step - start_step + 1;
uint32_t count = 0U;
for (auto step_num = start_step; step_num <= end_step; step_num++) {
auto idx = (step_num - 1) * records_per_step_ + CONNECTOR_DEPTH_OFFSET;
count += static_cast<uint32_t>(records_[idx].value == 0);
}
*empty_queue_freq = static_cast<float_t>(count) / static_cast<float_t>(total);
return Status::OK();
}
Status DeviceQueueTracing::GetConnectorCapacity(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, CONNECTOR_DEPTH_OFFSET, "extra_info", result);
}
Path DeviceQueueTracing::GetFileName(const std::string &dir_path, const std::string &rank_id) {
return Path(dir_path) / Path("device_queue_profiling_" + rank_id + ".txt");
}

View File

@@ -24,26 +24,16 @@
namespace mindspore {
namespace dataset {
constexpr int32_t RECORDS_PER_STEP_DEVICE_QUEUE = 4;
class DeviceQueueTracing : public Tracing {
public:
// Constructor
DeviceQueueTracing() : Tracing(RECORDS_PER_STEP_DEVICE_QUEUE) {}
DeviceQueueTracing() = default;
// Destructor
~DeviceQueueTracing() override = default;
std::string Name() const override { return kDeviceQueueTracingName; };
Status Init() override;
Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetConnectorCapacity(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) override;
private:
Path GetFileName(const std::string &dir_path, const std::string &rank_id) override;
};

View File

@@ -36,6 +36,11 @@
namespace mindspore {
namespace dataset {
constexpr int32_t PUSH_TIME_OFFSET = 0;
constexpr int32_t BATCH_TIME_OFFSET = 1;
constexpr int32_t PIPELINE_TIME_OFFSET = 2;
constexpr int32_t CONNECTOR_DEPTH_OFFSET = 3;
Status Profiling::Start() {
CHECK_FAIL_RETURN_UNEXPECTED(active_ == false, "Profiling node is already active.");
active_ = true;
@@ -106,7 +111,8 @@ void Tracing::Record(const int32_t type, const int32_t extra_info, const int32_t
(void)records_.emplace_back(record);
(void)value_.emplace_back(record.ToString());
// save timestamp per batch
if (records_.size() % records_per_step_ == 0) {
constexpr int32_t RECORDS_PER_STEP = 4;
if (records_.size() % RECORDS_PER_STEP == 0) {
(void)ts_.emplace_back(time_stamp);
}
}
@@ -150,9 +156,10 @@ Status Tracing::StepIntervalForTimeRange(uint64_t start_ts, uint64_t end_ts, int
}
Status Tracing::GetRecordEntryFieldValue(int32_t start_step, int32_t end_step, int32_t record_offset,
const std::string field, std::vector<int32_t> *result) {
const std::string &field, std::vector<int32_t> *result) {
std::lock_guard<std::mutex> guard(lock_);
auto total_steps = records_.size() / records_per_step_;
constexpr int32_t RECORDS_PER_STEP = 4;
auto total_steps = records_.size() / RECORDS_PER_STEP;
MS_LOG(DEBUG) << "start_step: " << start_step << " end_step: " << end_step;
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= total_steps,
"Expected start_step <= total_steps. Got start_step: " + std::to_string(start_step) +
@@ -165,7 +172,7 @@ Status Tracing::GetRecordEntryFieldValue(int32_t start_step, int32_t end_step, i
" end_step: " + std::to_string(end_step));
for (auto step_num = start_step; step_num <= end_step; step_num++) {
auto idx = (step_num - 1) * records_per_step_ + record_offset;
auto idx = (step_num - 1) * RECORDS_PER_STEP + record_offset;
if (field == "value") {
(void)result->emplace_back(records_[idx].value);
} else if (field == "extra_info") {
@@ -178,7 +185,40 @@ Status Tracing::GetRecordEntryFieldValue(int32_t start_step, int32_t end_step, i
return Status::OK();
}
Tracing::Tracing(int32_t records_per_step) : records_per_step_(records_per_step) {}
Status Tracing::GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, PIPELINE_TIME_OFFSET, "value", result);
}
Status Tracing::GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, PUSH_TIME_OFFSET, "value", result);
}
Status Tracing::GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, BATCH_TIME_OFFSET, "value", result);
}
Status Tracing::GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, CONNECTOR_DEPTH_OFFSET, "value", result);
}
Status Tracing::GetConnectorCapacity(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntryFieldValue(start_step, end_step, CONNECTOR_DEPTH_OFFSET, "extra_info", result);
}
Status Tracing::GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) {
std::vector<int32_t> sizes;
RETURN_IF_NOT_OK(GetConnectorSize(start_step, end_step, &sizes));
int32_t total = end_step - start_step + 1;
CHECK_FAIL_RETURN_UNEXPECTED(total > 0, "Start step is greater than end step.");
uint32_t count = std::count(sizes.begin(), sizes.end(), 0);
*empty_queue_freq = static_cast<float_t>(count) / static_cast<float_t>(total);
return Status::OK();
}
Status Tracing::Init() {
(void)ts_.emplace_back(0);
return Status::OK();
}
// Constructor
ProfilingManager::ProfilingManager()

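With the getters and GetEmptyQueueFrequency hoisted into the Tracing base class, the empty-queue frequency is simply the fraction of zero connector sizes over the requested steps. A standalone worked example with illustrative sizes (the RECORDS_PER_STEP bookkeeping is omitted):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Pretend GetConnectorSize(start_step=2, end_step=5, &sizes) returned these:
  std::vector<int32_t> sizes = {0, 3, 0, 5};
  int32_t start_step = 2, end_step = 5;
  int32_t total = end_step - start_step + 1;                                      // 4 steps, must be > 0
  auto count = static_cast<uint32_t>(std::count(sizes.begin(), sizes.end(), 0));  // 2 empty
  float empty_queue_freq = static_cast<float>(count) / static_cast<float>(total);
  std::cout << "empty queue frequency: " << empty_queue_freq << "\n";             // 0.5
  return 0;
}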
View File

@@ -102,24 +102,24 @@ class Tracing : public Profiling {
// It only includes some common routines.
Status SaveToFile(const std::string &dir_path, const std::string &rank_id) override;
Status ChangeFileMode(const std::string &dir_path, const std::string &rank_id) override;
virtual Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetConnectorCapacity(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) = 0;
Status Init() override;
Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result);
Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result);
Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result);
Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result);
Status GetConnectorCapacity(int32_t start_step, int32_t end_step, std::vector<int32_t> *result);
Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq);
void Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value,
const uint64_t time_stamp);
Status TimeIntervalForStepRange(int32_t start_step, int32_t end_step, uint64_t *start_ts, uint64_t *end_ts);
Status StepIntervalForTimeRange(uint64_t start_ts, uint64_t end_ts, int32_t *start_step, int32_t *end_step);
protected:
explicit Tracing(int32_t records_per_step);
const int32_t records_per_step_;
Tracing() = default;
std::vector<std::string> value_;
std::vector<TracingRecord> records_;
std::vector<uint64_t> ts_; // End time of each step or batch
Status GetRecordEntryFieldValue(int32_t start_step, int32_t end_step, int32_t record_offset, std::string field,
Status GetRecordEntryFieldValue(int32_t start_step, int32_t end_step, int32_t record_offset, const std::string &field,
std::vector<int32_t> *result);
};

View File

@@ -232,6 +232,13 @@ Status TreeAdapter::GetNext(TensorRow *row) {
if (!launched_) {
RETURN_IF_NOT_OK(Launch());
}
// Record profiling info
#ifndef ENABLE_SECURITY
uint64_t start_time = 0;
if (tracing_ != nullptr) {
start_time = ProfilingTime::GetCurMilliSecond();
}
#endif
RETURN_IF_NOT_OK(tree_->root()->GetNextRow(row)); // first buf can't be eof or empty buf with none flag
if (row->eoe()) { // return empty tensor if 1st buf is a ctrl buf (no rows)
@@ -257,6 +264,10 @@ Status TreeAdapter::GetNext(TensorRow *row) {
cur_batch_num_++;
cur_connector_size_ = tree_->root()->ConnectorSize();
cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
// push time is 0ms in dataset iterator since no devices are involved
tracing_->Record(TIME, TDT_PUSH_TIME, cur_batch_num_, 0, end_time);
tracing_->Record(TIME, BATCH_TIME, cur_batch_num_, end_time - start_time, end_time);
tracing_->Record(TIME, PIPELINE_TIME, cur_batch_num_, end_time - start_time, end_time);
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_, end_time);
}
#endif

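These Record calls make the dataset iterator trace emit the same four entries per step as the device-queue trace (push time, batch time, pipeline time, connector depth), which is what lets the shared base-class getters and the RECORDS_PER_STEP == 4 constant in profiling.cc serve both traces. A standalone sketch of the per-step layout; the enum values are placeholders, not the real profiling headers:

#include <cstdint>
#include <vector>

// Placeholder values; the real TIME / CONNECTOR_DEPTH / TDT_PUSH_TIME / BATCH_TIME /
// PIPELINE_TIME identifiers come from the MindSpore profiling headers.
constexpr int32_t TIME = 0;
constexpr int32_t CONNECTOR_DEPTH = 1;
constexpr int32_t TDT_PUSH_TIME = 0;
constexpr int32_t BATCH_TIME = 1;
constexpr int32_t PIPELINE_TIME = 2;

struct Record { int32_t type; int32_t extra_info; int32_t batch_num; int32_t value; };

int main() {
  std::vector<Record> records;
  const int32_t batch_num = 1;
  const int32_t elapsed_ms = 12;                        // hypothetical batch/pipeline time
  const int32_t connector_size = 3, connector_capacity = 16;
  // One GetNext() appends exactly four records, matching offsets 0..3 in profiling.cc:
  records.push_back({TIME, TDT_PUSH_TIME, batch_num, 0});  // push time is 0, no device involved
  records.push_back({TIME, BATCH_TIME, batch_num, elapsed_ms});
  records.push_back({TIME, PIPELINE_TIME, batch_num, elapsed_ms});
  records.push_back({CONNECTOR_DEPTH, connector_capacity, batch_num, connector_size});
  return records.size() % 4 == 0 ? 0 : 1;               // ts_ gets a timestamp every 4 records
}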
View File

@@ -257,9 +257,9 @@ TEST_F(MindDataTestCallback, TestMultiEpochCallback) {
// config EpochCtrlOp
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(num_repeats);
// start build then launch tree
leaf->SetTotalRepeats(-2);
leaf->SetTotalRepeats(4);
leaf->SetNumRepeatsPerEpoch(2);
map_op->SetTotalRepeats(-2);
map_op->SetTotalRepeats(4);
map_op->SetNumRepeatsPerEpoch(2);
std::shared_ptr<ExecutionTree> tree = Build({leaf, map_op, repeat_op, epoch_ctrl_op});
rc = tree->Prepare();
@@ -328,9 +328,9 @@ TEST_F(MindDataTestCallback, TestSelectedCallback) {
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op = std::make_shared<EpochCtrlOp>(2);
// start build then launch tree
leaf->SetTotalRepeats(-2);
leaf->SetTotalRepeats(4);
leaf->SetNumRepeatsPerEpoch(2);
map_op->SetTotalRepeats(-2);
map_op->SetTotalRepeats(4);
map_op->SetNumRepeatsPerEpoch(2);
std::shared_ptr<ExecutionTree> tree = Build({leaf, map_op, repeat_op, epoch_ctrl_op});
rc = tree->Prepare();
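The tests now pass 4 instead of -2 for the total repeats; assuming total repeats is epochs times repeats-per-epoch, 2 * 2 = 4 is consistent with SetNumRepeatsPerEpoch(2) and the two-epoch EpochCtrlOp in both trees. A one-line check of that arithmetic, purely to make the numbers explicit:

#include <cassert>

int main() {
  const int num_epochs = 2;         // both tests run two epochs
  const int repeats_per_epoch = 2;  // SetNumRepeatsPerEpoch(2)
  assert(num_epochs * repeats_per_epoch == 4);  // value now passed to SetTotalRepeats
  return 0;
}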