Implement autotune api

This commit is contained in:
harshvardhangupta 2021-09-21 17:21:58 -04:00
parent b1e8781d6b
commit 1bb45142c9
24 changed files with 1248 additions and 1349 deletions

View File

@ -151,8 +151,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
#ifndef ENABLE_SECURITY
if (tracing_ != nullptr) {
cur_batch_num_++;
RETURN_IF_NOT_OK(tracing_->Record(static_cast<int32_t>(CONNECTOR_DEPTH), cur_connector_capacity_, cur_batch_num_,
cur_connector_size_, ProfilingTime::GetCurMilliSecond()));
tracing_->Record(static_cast<int32_t>(CONNECTOR_DEPTH), cur_connector_capacity_, cur_batch_num_,
cur_connector_size_, ProfilingTime::GetCurMilliSecond());
}
#endif
return Status::OK();

View File

@ -389,7 +389,7 @@ Status DeviceQueueOp::LaunchParallelCopyThread() {
RETURN_IF_NOT_OK(CircularPool::CreateCircularPool(&pool, -1, kDeviceQueGpuThreadMemory, false, true));
pool_.push_back(pool);
}
gpu_item_connector_ = std::make_unique<GpuItemConnector>(num_workers_, 1, queue_capacity_);
gpu_connector_ = std::make_unique<GpuConnector>(num_workers_, 1, queue_capacity_);
receive_queues_.Init(num_workers_, queue_capacity_);
RETURN_IF_NOT_OK(receive_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
@ -417,73 +417,78 @@ Status DeviceQueueOp::PushDataToGPU() {
RETURN_IF_NOT_OK(tree_->GetProfilingManager()->GetTracingNode(kDeviceQueueTracingName, &node));
profiling_node = std::dynamic_pointer_cast<DeviceQueueTracing>(node);
batch_start_time = ProfilingTime::GetCurMilliSecond();
connector_capacity = gpu_item_connector_->capacity();
connector_capacity = gpu_connector_->capacity();
}
#endif
#ifdef ENABLE_DUMP_IR
md_channel_info_->RecordBatchQueue(gpu_item_connector_->size());
md_channel_info_->RecordBatchQueue(gpu_connector_->size());
md_channel_info_->RecordPreprocessBatch(0);
#endif
std::vector<device::DataItemGpu> items;
RETURN_IF_NOT_OK(gpu_item_connector_->Pop(0, &items));
GpuConnectorItem item;
RETURN_IF_NOT_OK(gpu_connector_->Pop(0, &item));
auto items = std::move(item.data_item);
bool eoe_flag = item.eoe_flag;
int64_t send_batch = 0;
bool is_open = false;
uint32_t handle = INVALID_HANDLE;
auto release_function = std::bind(&DeviceQueueOp::ReleaseData, this, std::placeholders::_1, std::placeholders::_2);
while (!items.empty() && !GpuBufferMgr::GetInstance().IsClosed()) {
#ifdef ENABLE_DUMP_IR
md_channel_info_->RecordBatchQueue(gpu_item_connector_->size());
md_channel_info_->RecordPreprocessBatch(send_batch);
md_channel_info_->RecordPushStartTime();
#endif
if (!is_open) {
std::vector<size_t> data_size;
for (int32_t index = 0; index < items.size(); index++) {
data_size.push_back(items[index].data_len_);
}
handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, data_size, release_function);
if (handle == INVALID_HANDLE) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"[Internal ERROR] Failed to open channel for sending data.");
}
is_open = true;
}
handle = GpuBufferMgr::GetInstance().Open(0, channel_name_, {}, release_function);
if (handle == INVALID_HANDLE) {
return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__,
"[Internal ERROR] Failed to open channel for sending data.");
}
// Data prefetch only when PS mode enables cache.
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_,
items[0].data_type_)) {
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__,
"Failed to prefetch data in current PS mode(cache data when sending).");
}
RETURN_IF_NOT_OK(RetryPushData(handle, items));
send_batch++;
while (!(items.empty() && !eoe_flag) && !GpuBufferMgr::GetInstance().IsClosed()) {
if (!eoe_flag) {
#ifdef ENABLE_DUMP_IR
md_channel_info_->RecordBatchQueue(gpu_connector_->size());
md_channel_info_->RecordPreprocessBatch(send_batch);
md_channel_info_->RecordPushStartTime();
#endif
// Data prefetch only when PS mode enables cache.
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_,
items[0].data_type_)) {
return Status(StatusCode::kMDTimeOut, __LINE__, __FILE__,
"Failed to prefetch data in current PS mode(cache data when sending).");
}
RETURN_IF_NOT_OK(RetryPushData(handle, items));
send_batch++;
#ifndef ENABLE_SECURITY
if (is_profiling_enable) {
uint64_t end_time = ProfilingTime::GetCurMilliSecond();
// record push data time
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
int32_t batch_cost = (int32_t)(end_time - batch_start_time);
// record batch time
profiling_node->Record(TIME, BATCH_TIME, send_batch, batch_cost, end_time);
// record pipeline time
profiling_node->Record(TIME, PIPELINE_TIME, send_batch, batch_cost - push_cost, end_time);
batch_start_time = end_time;
// record connector depth
profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, send_batch, connector_size, end_time);
connector_size = gpu_item_connector_->size();
connector_capacity = gpu_item_connector_->capacity();
}
if (is_profiling_enable) {
uint64_t end_time = ProfilingTime::GetCurMilliSecond();
// record push data time
profiling_node->Record(TIME, TDT_PUSH_TIME, send_batch, push_cost, end_time);
int32_t batch_cost = (int32_t)(end_time - batch_start_time);
// record batch time
profiling_node->Record(TIME, BATCH_TIME, send_batch, batch_cost, end_time);
// record pipeline time
profiling_node->Record(TIME, PIPELINE_TIME, send_batch, batch_cost - push_cost, end_time);
batch_start_time = end_time;
// record connector depth
profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, send_batch, connector_size, end_time);
connector_size = gpu_connector_->size();
connector_capacity = gpu_connector_->capacity();
}
#endif
#ifdef ENABLE_DUMP_IR
md_channel_info_->RecordBatchQueue(gpu_item_connector_->size());
md_channel_info_->RecordPreprocessBatch(send_batch);
md_channel_info_->RecordPushEndTime();
md_channel_info_->RecordBatchQueue(gpu_connector_->size());
md_channel_info_->RecordPreprocessBatch(send_batch);
md_channel_info_->RecordPushEndTime();
#endif
if (total_batch_ > 0 && send_batch >= total_batch_) {
break;
}
} else {
#ifndef ENABLE_SECURITY
if (is_profiling_enable) {
tree_->SetEpochEnd();
tree_->GetProfilingManager()->RecordEndOfEpoch(send_batch);
}
#endif
if (total_batch_ > 0 && send_batch >= total_batch_) {
break;
}
if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed()) {
auto rc = gpu_item_connector_->Pop(0, &items);
auto rc = gpu_connector_->Pop(0, &item);
items = std::move(item.data_item);
eoe_flag = item.eoe_flag;
// If the dataset sends more batches than the GPU consumes, the GPU will core dump because no signal is notified.
if (rc.IsError()) {
GpuBufferMgr::GetInstance().Close(handle);
@ -543,25 +548,30 @@ Status DeviceQueueOp::WorkerEntry(int32_t worker_id) {
uint32_t batch_num = 0;
RETURN_IF_NOT_OK(receive_queues_[worker_id]->PopFront(&current_row));
while (!current_row.quit() && !GpuBufferMgr::GetInstance().IsClosed()) {
std::vector<device::DataItemGpu> items;
for (int i = 0; i < current_row.size(); i++) {
device::DataItemGpu data_item;
data_item.data_len_ = static_cast<size_t>(current_row[i]->SizeInBytes());
data_item.data_ptr_ = nullptr;
data_item.worker_id_ = worker_id;
items.push_back(data_item);
GpuConnectorItem connector_item = {{}, current_row.eoe()};
if (!connector_item.eoe_flag) {
std::vector<device::DataItemGpu> items;
for (auto &i : current_row) {
device::DataItemGpu data_item;
data_item.data_len_ = static_cast<size_t>(i->SizeInBytes());
data_item.data_ptr_ = nullptr;
data_item.worker_id_ = worker_id;
items.push_back(data_item);
}
RETURN_IF_NOT_OK(MallocForGPUData(&items, current_row, worker_id));
connector_item.data_item = std::move(items);
batch_num++;
} else {
MS_LOG(INFO) << "EOE Detected";
}
RETURN_IF_NOT_OK(MallocForGPUData(&items, current_row, worker_id));
RETURN_IF_NOT_OK(gpu_item_connector_->Add(worker_id, std::move(items)));
batch_num++;
RETURN_IF_NOT_OK(gpu_connector_->Add(worker_id, std::move(connector_item)));
RETURN_IF_NOT_OK(receive_queues_[worker_id]->PopFront(&current_row));
}
MS_LOG(INFO) << "Device queue worker id " << worker_id << "proc " << batch_num << "batch.";
// Add empty vector as quit flag.
std::vector<device::DataItemGpu> items;
RETURN_IF_NOT_OK(gpu_item_connector_->Add(worker_id, std::move(items)));
// Add empty data_item vector with eoe_flag=false as quit flag.
GpuConnectorItem connector_item = {{}, false};
RETURN_IF_NOT_OK(gpu_connector_->Add(worker_id, std::move(connector_item)));
return Status::OK();
}
@ -599,12 +609,12 @@ Status DeviceQueueOp::SendDataToGPU() {
}
}
#ifndef ENABLE_SECURITY
if (current_row.eoe() && tree_->GetProfilingManager()->IsProfilingEnable()) {
tree_->SetEpochEnd();
tree_->GetProfilingManager()->RecordEndOfEpoch(batch_num);
if (current_row.eoe()) {
MS_LOG(INFO) << "EOE Detected";
TensorRow eoe_flag(TensorRow::kFlagEOE);
RETURN_IF_NOT_OK(receive_queues_[num_buf % num_workers_]->Add(std::move(eoe_flag)));
}
#endif
if (!TaskManager::FindMe()->Interrupted() && !GpuBufferMgr::GetInstance().IsClosed()) {
RETURN_IF_NOT_OK(child_iterator_->FetchNextTensorRow(&current_row));
} else {
@ -613,6 +623,7 @@ Status DeviceQueueOp::SendDataToGPU() {
}
for (uint32_t index = 0; index < num_workers_; index++) {
MS_LOG(INFO) << "Adding quit flag to Workers";
TensorRow quit_flag(TensorRow::kFlagQuit);
RETURN_IF_NOT_OK(receive_queues_[num_buf++ % num_workers_]->Add(std::move(quit_flag)));
}
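
The hunks above establish a three-way sentinel convention between SendDataToGPU, WorkerEntry, and PushDataToGPU: an EOE TensorRow marks an epoch boundary, a quit-flag TensorRow shuts a worker down, and the worker forwards both through the connector as GpuConnectorItem values. A minimal sketch of how a consumer can tell the cases apart (the Classify helper is hypothetical; GpuConnectorItem is defined in gpu_connector.h below):

// Sentinel convention carried by GpuConnectorItem:
//   batch -> non-empty data_item
//   EOE   -> empty data_item, eoe_flag == true   (epoch boundary, keep reading)
//   EOF   -> empty data_item, eoe_flag == false  (worker quit, stop reading)
enum class ItemKind { kBatch, kEoe, kEof };

ItemKind Classify(const GpuConnectorItem &item) {
  if (!item.data_item.empty()) return ItemKind::kBatch;
  return item.eoe_flag ? ItemKind::kEoe : ItemKind::kEof;
}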

View File

@ -152,7 +152,7 @@ class DeviceQueueOp : public PipelineOp {
QueueList<TensorRow> receive_queues_;
std::vector<std::shared_ptr<MemoryPool>> pool_;
std::unique_ptr<GpuItemConnector> gpu_item_connector_;
std::unique_ptr<GpuConnector> gpu_connector_;
const uint32_t kDeviceQueGpuNumThreads = 2;
const uint32_t kDeviceQueGpuQueueCapacity = 8;
const uint32_t kDeviceQueGpuThreadMemory = 1024;

View File

@ -235,7 +235,7 @@ void ExecutionTree::Iterator::PostOrderTraverse(const std::shared_ptr<DatasetOp>
ExecutionTree::Iterator::Iterator(const std::shared_ptr<DatasetOp> &root) : ind_(0) {
// Post-order traverse the tree; if root is null, it returns immediately.
PostOrderTraverse(root);
nodes_.emplace_back(nullptr);
(void)nodes_.emplace_back(nullptr);
}
// Given the number of workers, launches the worker entry function for each. Essentially a

View File

@ -30,34 +30,41 @@ using mindspore::device::DataItemGpu;
namespace mindspore {
namespace dataset {
class GpuItemConnector : public Connector<std::vector<device::DataItemGpu>> {
struct GpuConnectorItem {
std::vector<device::DataItemGpu> data_item;
bool eoe_flag; // flag to indicate an EOE item in the connector
};
class GpuConnector : public Connector<GpuConnectorItem> {
public:
GpuItemConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity)
: Connector<std::vector<device::DataItemGpu>>(num_producers, num_consumers, queue_capacity) {
GpuConnector(int32_t num_producers, int32_t num_consumers, int32_t queue_capacity)
: Connector<GpuConnectorItem>(num_producers, num_consumers, queue_capacity) {
for (int i = 0; i < num_producers; i++) {
is_queue_finished_.push_back(false);
}
}
~GpuItemConnector() = default;
~GpuConnector() = default;
Status Add(int32_t worker_id, std::vector<device::DataItemGpu> &&element) noexcept {
return Connector<std::vector<device::DataItemGpu>>::Push(worker_id, std::move(element));
Status Add(int32_t worker_id, GpuConnectorItem &&element) noexcept {
return Connector<GpuConnectorItem>::Push(worker_id, std::move(element));
}
Status Pop(int32_t worker_id, std::vector<device::DataItemGpu> *result) noexcept override {
Status Pop(int32_t worker_id, GpuConnectorItem *result) noexcept override {
RETURN_UNEXPECTED_IF_NULL(result);
{
MS_ASSERT(worker_id < num_consumers_);
std::unique_lock<std::mutex> lock(m_);
RETURN_IF_NOT_OK(cv_.Wait(&lock, [this, worker_id]() { return expect_consumer_ == worker_id; }));
if (is_queue_finished_[pop_from_]) {
std::string errMsg = "ERROR: popping from a finished queue in GpuItemConnector";
std::string errMsg = "ERROR: popping from a finished queue in GpuConnector";
RETURN_STATUS_UNEXPECTED(errMsg);
}
RETURN_IF_NOT_OK(queues_[pop_from_]->PopFront(result));
if ((*result).empty()) {
// empty data_item and eoe_flag=false is EOF
if ((*result).data_item.empty() && !(*result).eoe_flag) {
is_queue_finished_[pop_from_] = true;
}
@ -81,5 +88,5 @@ class GpuItemConnector : public Connector<std::vector<device::DataItemGpu>> {
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GPU_ITEM_CONNECTOR_H_
#endif // ENABLE_GPUQUE
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GPU_ITEM_CONNECTOR_H_
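
Taken together, Add and Pop give each producer its own queue while the single consumer drains them round-robin (the expect_consumer_ and pop_from_ bookkeeping inherited from Connector). A hedged sketch of the handshake, with error handling elided and names taken from this header and the DeviceQueueOp diff above:

// Producer side (one call per worker thread):
GpuConnectorItem batch{std::move(items), /*eoe_flag=*/false};
RETURN_IF_NOT_OK(gpu_connector_->Add(worker_id, std::move(batch)));

// Consumer side (DeviceQueueOp::PushDataToGPU pops as consumer 0):
GpuConnectorItem item;
RETURN_IF_NOT_OK(gpu_connector_->Pop(0, &item));
// An empty data_item with eoe_flag == false marks the producer's queue finished.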

View File

@ -1,9 +1,9 @@
add_library(engine-perf OBJECT
profiling.cc
monitor.cc
device_queue_tracing.cc
connector_size.cc
dataset_iterator_tracing.cc
connector_throughput.cc
cpu_sampling.cc
)
add_library(
engine-perf OBJECT
profiling.cc
monitor.cc
device_queue_tracing.cc
connector_size.cc
dataset_iterator_tracing.cc
cpu_sampler.cc
)

View File

@ -15,6 +15,8 @@
*/
#include "minddata/dataset/engine/perf/connector_size.h"
#include <fstream>
#include <algorithm>
#include <memory>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/path.h"
@ -27,10 +29,12 @@ using Qrow = std::vector<int>;
// Sample action
Status ConnectorSize::Sample() {
Qrow cur_row;
std::transform(tree_->begin(), tree_->end(), std::back_inserter(cur_row),
[](DatasetOp &op) { return op.ConnectorSize(); });
(void)std::transform(tree_->begin(), tree_->end(), std::back_inserter(cur_row),
[](DatasetOp &op) { return op.ConnectorSize(); });
std::lock_guard<std::mutex> guard(lock_);
// Push new row of sample
sample_table_.push_back(cur_row);
(void)ts_.emplace_back(ProfilingTime::GetCurMilliSecond());
return Status::OK();
}
@ -70,8 +74,8 @@ Status ConnectorSize::SaveToFile() {
// Traverse the ExecutionTree for JSON node generation
for (auto &node : *tree_) {
std::vector<int32_t> cur_queue_size;
std::transform(sample_table_.begin(), sample_table_.end(), std::back_inserter(cur_queue_size),
[&](const ConnectorSizeSample &sample) { return sample[idx]; });
(void)std::transform(sample_table_.begin(), sample_table_.end(), std::back_inserter(cur_queue_size),
[&](const ConnectorSizeSample &sample) { return sample[idx]; });
if (!path.Exists()) {
json json_node = ParseOpInfo(node, cur_queue_size);
output["op_info"].push_back(json_node);
@ -102,5 +106,37 @@ Status ConnectorSize::Init(const std::string &dir_path, const std::string &devic
}
Status ConnectorSize::Analyze() { return Status::OK(); }
Status ConnectorSize::GetOpConnectorSize(int32_t op_id, uint64_t start_time, uint64_t end_time,
std::vector<int32_t> *result) {
MS_LOG(DEBUG) << "Op_id: " << op_id << " start_ts: " << start_time << " end_ts: " << end_time;
CHECK_FAIL_RETURN_UNEXPECTED(start_time < end_time,
"Expected start_time < end_time. Got start_ts: " + std::to_string(start_time) +
" end_ts: " + std::to_string(end_time));
std::lock_guard<std::mutex> guard(lock_);
CHECK_FAIL_RETURN_UNEXPECTED(
ts_.size() == sample_table_.size(),
"Expected ts_.size() == sample_table_.size(). Got ts_.size: " + std::to_string(ts_.size()) +
" sample_table_.size: " + std::to_string(sample_table_.size()));
// find first ts that is not less than start_ts
auto lower = std::lower_bound(ts_.begin(), ts_.end(), start_time);
// find first ts that is greater than end_ts
auto upper = std::upper_bound(ts_.begin(), ts_.end(), end_time);
// get ts_ indices
auto start_index = std::distance(ts_.begin(), lower);
auto end_index = std::distance(ts_.begin(), upper);
MS_LOG(INFO) << "start_index: " << start_index << " end_index: " << end_index;
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
// convert indices to sample_table_ iterator
auto first_iter = sample_table_.begin() + start_index;
auto last_iter = sample_table_.begin() + end_index;
// op_id corresponds to the index in sample vector
(void)std::transform(first_iter, last_iter, std::back_inserter(*result),
[&](const ConnectorSizeSample &sample) { return sample[op_id]; });
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
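
GetOpConnectorSize maps a wall-clock window onto sample indices with a lower_bound/upper_bound pair; the same pattern reappears in cpu_sampler.cc below. A short worked example, assuming a 10 ms sampling interval: with ts_ = {0, 10, 20, 30, 40}, start_time = 5 and end_time = 25 give start_index = 1 (first timestamp >= 5) and end_index = 3 (first timestamp > 25), so samples 1 and 2 are copied out. A self-contained sketch of just the index computation (hypothetical helper, not part of this commit):

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// Returns [start_index, end_index) into ts for the time window [start_time, end_time].
std::pair<int64_t, int64_t> WindowIndices(const std::vector<uint64_t> &ts, uint64_t start_time, uint64_t end_time) {
  auto lower = std::lower_bound(ts.begin(), ts.end(), start_time);  // first ts >= start_time
  auto upper = std::upper_bound(ts.begin(), ts.end(), end_time);    // first ts >  end_time
  return {std::distance(ts.begin(), lower), std::distance(ts.begin(), upper)};
}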

View File

@ -40,6 +40,7 @@ class ConnectorSize : public Sampling {
// A circular buffer will be implemented in the future to make this table more flexible.
using ConnectorSizeSample = std::vector<int>;
using ConnectorSizeSampleTable = std::vector<ConnectorSizeSample>;
using Timestamps = std::vector<uint64_t>;
public:
explicit ConnectorSize(ExecutionTree *tree) : tree_(tree) {}
@ -62,13 +63,17 @@ class ConnectorSize : public Sampling {
json ParseOpInfo(const DatasetOp &node, const std::vector<int32_t> &size);
// Change file mode after save throughput data
Status ChangeFileMode() { return Status::OK(); }
Status ChangeFileMode() override { return Status::OK(); }
Status Analyze() override;
// Get the vector of connector sizes of given op for samples taken between start and end time
Status GetOpConnectorSize(int32_t op_id, uint64_t start_time, uint64_t end_time, std::vector<int32_t> *result);
private:
ExecutionTree *tree_ = nullptr; // ExecutionTree pointer
ConnectorSizeSampleTable sample_table_; // Dataset structure to store all samples of connector size sampling
Timestamps ts_; // time of sample
};
} // namespace dataset
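
For orientation, a hedged sketch of how this sampler is driven: the profiling monitor calls Sample() periodically, and autotune-style consumers later query a window. Names not in this header (kNumSamples, interval_ms, start_ts, end_ts) are illustrative:

ConnectorSize sampler(tree);
RETURN_IF_NOT_OK(sampler.Init(dir_path, device_id));
for (int i = 0; i < kNumSamples; ++i) {
  RETURN_IF_NOT_OK(sampler.Sample());  // appends one row of sizes plus a timestamp
  std::this_thread::sleep_for(std::chrono::milliseconds(interval_ms));
}
std::vector<int32_t> sizes;
RETURN_IF_NOT_OK(sampler.GetOpConnectorSize(/*op_id=*/0, start_ts, end_ts, &sizes));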

View File

@ -1,154 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/stat.h>
#include <iterator>
#include <algorithm>
#include <memory>
#include <string>
#include <nlohmann/json.hpp>
#include "minddata/dataset/engine/perf/connector_throughput.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
// temporary helper
int ConnectorThroughput::InitNodes() {
if (tree_ == nullptr) {
return 0;
}
auto it = (*tree_).begin();
return it.NumNodes();
}
// Sample action
Status ConnectorThroughput::Sample() {
std::vector<int64_t> out_row_count_row(n_nodes_);
std::vector<double> throughput_row(n_nodes_);
TimePoint cur_time; // initialised inside the loop, used outside the loop to update prev sample time.
auto col = 0;
for (const auto &node : *tree_) {
auto cur_out_rows_count = node.ConnectorOutRowsCount();
out_row_count_row[col] = cur_out_rows_count;
auto sz = timestamps_.size();
cur_time = std::chrono::steady_clock::now();
double data_time = 0;
if (sz > 1) {
auto full_time =
std::chrono::duration_cast<std::chrono::microseconds>(timestamps_[0][sz - 1] - timestamps_[0][sz - 2]);
data_time = std::chrono::duration<double>(full_time).count();
}
auto prev_out_rows_count = out_row_count_table_[col][out_row_count_table_.size() - 1];
if (data_time != 0) {
const int32_t multiplier = 1000;
auto thr = (cur_out_rows_count - prev_out_rows_count) / (multiplier * data_time);
throughput_row[col] = thr;
} else {
throughput_row[col] = 0;
}
col++;
}
std::vector<TimePoint> v = {cur_time}; // temporary fix
timestamps_.AddSample(v);
// Push new row of sample
out_row_count_table_.AddSample(out_row_count_row);
throughput_.AddSample(throughput_row);
return Status::OK();
}
json ConnectorThroughput::ParseOpInfo(const DatasetOp &node, const std::vector<double> &thr) {
auto children = node.Children();
std::vector<int32_t> children_id;
std::transform(children.begin(), children.end(), std::back_inserter(children_id),
[](const std::shared_ptr<DatasetOp> &op) -> int32_t { return op ? op->id() : 0; });
json json_node;
json_node["op_id"] = node.id();
json_node["op_type"] = node.Name();
json_node["num_workers"] = node.NumWorkers();
json metrics;
// DeviceQueueOp is a special op, it is not inlined but its output queue is invalid.
// So we should not output its connector throughput.
if (!node.inlined() && node.Name() != "DeviceQueueOp") {
metrics["output_queue"] = {{"throughput", thr}};
}
json_node["metrics"] = metrics;
if (!children_id.empty()) {
json_node["children"] = children_id;
}
return json_node;
}
// Save profiling data to file
// If the file already exists (created by another sampling node), simply add the data to the metrics field.
Status ConnectorThroughput::SaveToFile() {
json output;
RETURN_IF_NOT_OK(ReadJson(&output));
Path path = Path(file_path_);
// Traverse the ExecutionTree for JSON node generation
int col = 0;
for (auto &node : *tree_) {
std::vector<double> throughput;
if (throughput_.size() > col) {
for (auto i = 0; i < throughput_[col].size(); i++) {
throughput.push_back(throughput_[col][i]);
}
}
if (!path.Exists()) {
json json_node = ParseOpInfo(node, throughput);
output["op_info"].push_back(json_node);
} else {
if (!node.inlined() && node.Name() != "DeviceQueueOp") {
auto &ops_data = output["op_info"];
ops_data[col]["metrics"]["output_queue"]["throughput"] = throughput;
}
}
col++;
}
// Discard the content of the file when opening.
std::ofstream os(file_path_, std::ios::trunc);
os << output;
os.close();
return Status::OK();
}
Status ConnectorThroughput::Init(const std::string &dir_path, const std::string &device_id) {
file_path_ = (Path(dir_path) / Path("pipeline_profiling_" + device_id + ".json")).ToString();
Path path = Path(file_path_);
// Remove the file if it exists (from prior profiling usage)
RETURN_IF_NOT_OK(path.Remove());
return Status::OK();
}
Status ConnectorThroughput::ChangeFileMode() {
if (file_path_.empty()) {
return Status::OK();
}
if (chmod(common::SafeCStr(file_path_), S_IRUSR | S_IWUSR) == -1) {
std::string err_str = "Change file mode failed," + file_path_;
return Status(StatusCode::kMDUnexpectedError, err_str);
}
return Status::OK();
}
Status ConnectorThroughput::Analyze() { return Status::OK(); }
} // namespace dataset
} // namespace mindspore
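
For the record, the deleted Sample() above computed throughput as rows per millisecond: data_time is in seconds, so thr = delta_rows / (1000 * delta_t). A compact restatement with a worked value (hypothetical standalone helper, mirroring the removed logic):

// Throughput as computed by the removed sampler (units: rows per millisecond).
double Throughput(int64_t cur_rows, int64_t prev_rows, double dt_seconds) {
  const int32_t multiplier = 1000;
  if (dt_seconds == 0) return 0.0;
  return (cur_rows - prev_rows) / (multiplier * dt_seconds);
}
// Throughput(3000, 1000, 2.0) == 1.0, i.e. 1 row/ms or 1000 rows/s.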

View File

@ -1,92 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_THROUGHPUT_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_THROUGHPUT_H
#include <vector>
#include <chrono>
#include <fstream>
#include <string>
#include <nlohmann/json.hpp>
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/perf_data.h"
#include "minddata/dataset/engine/perf/cyclic_array.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/execution_tree.h"
using json = nlohmann::json;
namespace mindspore {
namespace dataset {
// Connector throughput samples the output connector size of each op in the pipeline.
// For the description of the data structure see perf_data.h
// It supports JSON serialization for external usage.
class ConnectorThroughput : public Sampling {
using OutRowCount = PerfData<CyclicArray<int64_t>>;
using Throughput = PerfData<CyclicArray<double>>;
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
using TimeStamps = PerfData<CyclicArray<TimePoint>>;
public:
explicit ConnectorThroughput(ExecutionTree *tree, int64_t max_rows = 1000000)
: tree_(tree),
max_rows_(max_rows),
n_nodes_(InitNodes()),
out_row_count_table_(OutRowCount(max_rows_, n_nodes_)),
throughput_(Throughput(max_rows_, n_nodes_)),
timestamps_(TimeStamps(max_rows_, 1)) {
timestamps_.AddSample(std::vector<TimePoint>(1));
out_row_count_table_.AddSample(std::vector<int64_t>(n_nodes_));
}
/// \brief Destructor
~ConnectorThroughput() = default;
// Driver function for connector size sampling.
// This function samples the connector size of every node within the ExecutionTree
Status Sample() override;
// Traverse the tree nodes and count them
int InitNodes();
std::string Name() const override { return name_; };
// Save sampling data to file
// @return Status The status code returned
Status SaveToFile() override;
Status Init(const std::string &dir_path, const std::string &device_id) override;
json ParseOpInfo(const DatasetOp &node, const std::vector<double> &thr);
Status ChangeFileMode() override;
Status Analyze() override;
private:
ExecutionTree *tree_ = nullptr; // ExecutionTree pointer
int64_t max_rows_;
int32_t n_nodes_;
OutRowCount out_row_count_table_;
Throughput throughput_;
TimeStamps timestamps_;
std::string name_ = kConnectorThroughputSamplingName;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CONNECTOR_THROUGHPUT_H

View File

@ -0,0 +1,511 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/perf/cpu_sampler.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <sys/syscall.h>
#endif
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <utility>
#include <fstream>
#include <memory>
#include <string>
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
using json = nlohmann::json;
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#define USING_LINUX
#endif
#if defined(USING_LINUX)
int32_t SystemCpuInfo::num_cpu_ = get_nprocs_conf();
#else
int32_t SystemCpuInfo::num_cpu_ = 0;
#endif
Status SystemCpuInfo::ParseCpuInfo(const std::string &str) {
SystemStat system_cpu_stat;
uint64_t nice = 0;
uint64_t irq = 0;
uint64_t softirq = 0;
if (sscanf_s(str.c_str(), "%*s %lu %lu %lu %lu %lu %lu %lu", &system_cpu_stat.user_stat, &nice,
&system_cpu_stat.sys_stat, &system_cpu_stat.idle_stat, &system_cpu_stat.io_stat, &irq,
&softirq) == EOF) {
return Status(StatusCode::kMDUnexpectedError, "Get System CPU failed.");
}
system_cpu_stat.total_stat = system_cpu_stat.user_stat + nice + system_cpu_stat.sys_stat + system_cpu_stat.idle_stat +
system_cpu_stat.io_stat + irq + softirq;
SystemUtil system_cpu_util = {0, 0, 0, 0};
// Calculate the utilization from the second sampling
if (!first_sample_) {
system_cpu_util.user_utilization = round((system_cpu_stat.user_stat - prev_sys_stat_.user_stat) * 1.0 /
(system_cpu_stat.total_stat - prev_sys_stat_.total_stat) * 100);
system_cpu_util.sys_utilization = round((system_cpu_stat.sys_stat - prev_sys_stat_.sys_stat) * 1.0 /
(system_cpu_stat.total_stat - prev_sys_stat_.total_stat) * 100);
system_cpu_util.io_utilization = round((system_cpu_stat.io_stat - prev_sys_stat_.io_stat) * 1.0 /
(system_cpu_stat.total_stat - prev_sys_stat_.total_stat) * 100);
system_cpu_util.idle_utilization = round((system_cpu_stat.idle_stat - prev_sys_stat_.idle_stat) * 1.0 /
(system_cpu_stat.total_stat - prev_sys_stat_.total_stat) * 100);
}
// append the 0 util as well to maintain sys_cpu_util_.size == ts_.size
(void)sys_cpu_util_.emplace_back(system_cpu_util);
prev_sys_stat_ = system_cpu_stat;
return Status::OK();
}
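
ParseCpuInfo derives the four percentages from jiffy deltas between consecutive samples of the aggregate "cpu" line of /proc/stat (user, nice, system, idle, iowait, irq, softirq). A hypothetical restatement of the per-field formula with a worked value:

#include <cmath>
#include <cstdint>

// Per-field utilization formula used above (all inputs in jiffies).
uint8_t FieldUtilization(uint64_t cur, uint64_t prev, uint64_t cur_total, uint64_t prev_total) {
  return static_cast<uint8_t>(std::round((cur - prev) * 1.0 / (cur_total - prev_total) * 100));
}
// FieldUtilization(1200, 1000, 11000, 10000) == 20 (percent)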
Status SystemCpuInfo::ParseCtxt(const std::string &str) {
uint64_t ctxt;
if (sscanf_s(str.c_str(), "%*s %lu", &ctxt) == EOF) {
return Status(StatusCode::kMDUnexpectedError, "Get context switch count failed.");
}
// first context switch count will be 0
auto val = first_sample_ ? 0 : ctxt - prev_context_switch_count_;
context_switch_count_.push_back(val);
prev_context_switch_count_ = ctxt;
return Status::OK();
}
Status SystemCpuInfo::ParseRunningProcess(const std::string &str) {
uint32_t running_process;
if (sscanf_s(str.c_str(), "%*s %ud", &running_process) == EOF) {
return Status(StatusCode::kMDUnexpectedError, "Get context switch count failed.");
}
running_process_.push_back(running_process);
return Status::OK();
}
Status SystemCpuInfo::SampleAndGetCurrPrevStat(SystemStat *current_stat, SystemStat *previous_stat) {
std::ifstream file("/proc/stat");
if (!file.is_open()) {
MS_LOG(INFO) << "Failed to open /proc/stat file";
return {StatusCode::kMDUnexpectedError, "Failed to open /proc/stat file"};
}
*previous_stat = prev_sys_stat_;
bool first_line = true;
std::string line;
while (getline(file, line)) {
if (first_line) {
first_line = false;
RETURN_IF_NOT_OK(ParseCpuInfo(line));
}
if (line.find("ctxt") != std::string::npos) {
RETURN_IF_NOT_OK(ParseCtxt(line));
}
if (line.find("procs_running") != std::string::npos) {
RETURN_IF_NOT_OK(ParseRunningProcess(line));
}
}
// after the loop above, prev_sys_stat_ has the current value
*current_stat = prev_sys_stat_;
file.close();
first_sample_ = false;
return Status::OK();
}
Status SystemCpuInfo::GetUserCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint8_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< " sys_cpu_util.size: " << sys_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= sys_cpu_util_.size(),
"Expected end_index <= sys_cpu_util_.size(). Got end_index: " + std::to_string(end_index) +
" sys_cpu_util_.size: " + std::to_string(sys_cpu_util_.size()));
(void)std::transform(sys_cpu_util_.begin() + start_index, sys_cpu_util_.begin() + end_index,
std::back_inserter(*result), [&](const SystemUtil &info) { return info.user_utilization; });
return Status::OK();
}
Status SystemCpuInfo::GetSysCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint8_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< "sys_cpu_util.size: " << sys_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= sys_cpu_util_.size(),
"Expected end_index <= sys_cpu_util_.size(). Got end_index: " + std::to_string(end_index) +
" sys_cpu_util_.size: " + std::to_string(sys_cpu_util_.size()));
(void)std::transform(sys_cpu_util_.begin() + start_index, sys_cpu_util_.begin() + end_index,
std::back_inserter(*result), [&](const SystemUtil &info) { return info.sys_utilization; });
return Status::OK();
}
std::vector<uint8_t> SystemCpuInfo::GetIOCpuUtil() const {
std::vector<uint8_t> io_util;
(void)std::transform(sys_cpu_util_.begin(), sys_cpu_util_.end(), std::back_inserter(io_util),
[&](const SystemUtil &info) { return info.io_utilization; });
return io_util;
}
std::vector<uint8_t> SystemCpuInfo::GetIdleCpuUtil() const {
std::vector<uint8_t> idle_util;
(void)std::transform(sys_cpu_util_.begin(), sys_cpu_util_.end(), std::back_inserter(idle_util),
[&](const SystemUtil &info) { return info.idle_utilization; });
return idle_util;
}
std::vector<uint16_t> TaskCpuInfo::GetSysCpuUtil() const {
std::vector<uint16_t> sys_util;
(void)std::transform(
task_cpu_util_.begin(), task_cpu_util_.end(), std::back_inserter(sys_util), [&](const TaskUtil &info) {
return static_cast<uint16_t>(info.sys_utilization * static_cast<float>(SystemCpuInfo::num_cpu_));
});
return sys_util;
}
std::vector<uint16_t> TaskCpuInfo::GetUserCpuUtil() const {
std::vector<uint16_t> user_util;
(void)std::transform(
task_cpu_util_.begin(), task_cpu_util_.end(), std::back_inserter(user_util), [&](const TaskUtil &info) {
return static_cast<uint16_t>(info.user_utilization * static_cast<float>(SystemCpuInfo::num_cpu_));
});
return user_util;
}
TaskUtil TaskCpuInfo::GetLatestCpuUtil() const {
TaskUtil ret = {0, 0};
if (!task_cpu_util_.empty() && !last_sampling_failed_) {
ret = task_cpu_util_.back();
}
return ret;
}
Status ProcessCpuInfo::Sample(uint64_t total_time_elapsed) {
std::ifstream file("/proc/" + std::to_string(pid_) + "/stat");
if (!file.is_open()) {
MS_LOG(INFO) << "Failed to open /proc/" << pid_ << "/stat/ file";
last_sampling_failed_ = true;
return Status::OK();
}
std::string str;
(void)getline(file, str);
uint64_t utime = 0, stime = 0;
if (sscanf_s(str.c_str(), "%*d %*s %*s %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %lu %lu", &utime, &stime) ==
EOF) {
file.close();
last_sampling_failed_ = true;
return Status(StatusCode::kMDUnexpectedError, "Get device CPU failed.");
}
file.close();
last_sampling_failed_ = false;
if (!first_sample_) {
float user_util = ((utime - prev_task_stat_.user_stat) * 1.0 / total_time_elapsed) * 100.0;
float sys_util = ((stime - prev_task_stat_.sys_stat) * 1.0 / total_time_elapsed) * 100.0;
(void)task_cpu_util_.emplace_back(TaskUtil{user_util, sys_util});
}
prev_task_stat_.user_stat = utime;
prev_task_stat_.sys_stat = stime;
first_sample_ = false;
return Status::OK();
}
Status ThreadCpuInfo::Sample(uint64_t total_time_elapsed) {
std::ifstream file("/proc/" + std::to_string(pid_) + "/task/" + std::to_string(tid_) + "/stat");
if (!file.is_open()) {
MS_LOG(INFO) << "Failed to open /proc/" << pid_ << "/task/" << tid_ << "/stat file";
last_sampling_failed_ = true;
return Status::OK();
}
std::string str;
(void)getline(file, str);
uint64_t utime;
uint64_t stime;
if (sscanf_s(str.c_str(), "%*d %*s %*s %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %lu %lu", &utime, &stime) ==
EOF) {
file.close();
last_sampling_failed_ = true;
return Status(StatusCode::kMDUnexpectedError, "Get thread CPU failed.");
}
file.close();
last_sampling_failed_ = false;
if (!first_sample_) {
float user_util = ((utime - prev_task_stat_.user_stat) * 1.0 / total_time_elapsed) * 100.0;
float sys_util = ((stime - prev_task_stat_.sys_stat) * 1.0 / total_time_elapsed) * 100.0;
(void)task_cpu_util_.emplace_back(TaskUtil{user_util, sys_util});
}
prev_task_stat_.user_stat = utime;
prev_task_stat_.sys_stat = stime;
first_sample_ = false;
return Status::OK();
}
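
Both Sample overrides parse utime and stime (fields 14 and 15 of the task's stat file) and normalize the delta by total_time_elapsed, the system-wide jiffy delta computed in CpuSampler::Sample. A worked example with assumed numbers:

// Assume total_time_elapsed = 4000 jiffies across all CPUs and a utime delta of 200:
//   user_utilization = (200 * 1.0 / 4000) * 100 = 5.0
// GetUserCpuUtil() later scales by the core count, e.g. 5.0 * 8 = 40 on an 8-core machine.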
bool MDOperatorCpuInfo::TaskExists(pid_t id) const { return task_by_id_.find(id) != task_by_id_.end(); }
void MDOperatorCpuInfo::AddTask(const std::shared_ptr<TaskCpuInfo> &task_ptr) {
auto id = task_ptr->GetId();
if (!TaskExists(id)) {
(void)task_by_id_.emplace(id, task_ptr);
}
}
void MDOperatorCpuInfo::CalculateOperatorUtilization() {
OpUtil op_util{0, 0};
for (auto const &[task_id, task_ptr] : task_by_id_) {
MS_LOG(DEBUG) << "Processing task_id: " << task_id;
auto task_util = task_ptr->GetLatestCpuUtil();
op_util.user_utilization += task_util.user_utilization;
op_util.sys_utilization += task_util.sys_utilization;
}
(void)op_cpu_util_.emplace_back(op_util);
}
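
An operator's utilization is simply the sum of the latest readings of its tasks (threads and worker processes); a task whose last sampling failed contributes {0, 0} via GetLatestCpuUtil. For example, with assumed values, two worker threads at 5.0% and 7.5% user time give op_util.user_utilization = 12.5, exported as 12.5 * num_cpu_ = 100 on an 8-core machine.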
Status MDOperatorCpuInfo::GetUserCpuUtil(uint64_t start_index, uint64_t end_index,
std::vector<uint16_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< " op_cpu_util_.size: " << op_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= op_cpu_util_.size(),
"Expected end_index <= op_cpu_util_.size(). Got end_index: " + std::to_string(end_index) +
" op_cpu_util_.size: " + std::to_string(op_cpu_util_.size()));
auto first_iter = op_cpu_util_.begin() + start_index;
auto last_iter = op_cpu_util_.begin() + end_index;
(void)std::transform(first_iter, last_iter, std::back_inserter(*result), [&](const OpUtil &info) {
return static_cast<uint16_t>(info.user_utilization * static_cast<float>(SystemCpuInfo::num_cpu_));
});
return Status::OK();
}
Status MDOperatorCpuInfo::GetSysCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint16_t> *result) const {
MS_LOG(DEBUG) << "start_index: " << start_index << " end_index: " << end_index
<< " op_cpu_util_.size: " << op_cpu_util_.size();
CHECK_FAIL_RETURN_UNEXPECTED(start_index < end_index,
"Expected start_index < end_index. Got start_index: " + std::to_string(start_index) +
" end_index: " + std::to_string(end_index));
CHECK_FAIL_RETURN_UNEXPECTED(
end_index <= op_cpu_util_.size(),
"Expected end_index <= op_cpu_util_.size(). Got end_index: " + std::to_string(end_index) +
" op_cpu_util_.size: " + std::to_string(op_cpu_util_.size()));
auto first_iter = op_cpu_util_.begin() + start_index;
auto last_iter = op_cpu_util_.begin() + end_index;
(void)std::transform(first_iter, last_iter, std::back_inserter(*result), [&](const OpUtil &info) {
return static_cast<uint16_t>(info.sys_utilization * static_cast<float>(SystemCpuInfo::num_cpu_));
});
return Status::OK();
}
Status CpuSampler::Sample() {
std::lock_guard<std::mutex> guard(lock_);
// Function to Update TaskList
// Loop through all tasks to find any new threads
// Get all multi-processing Ops from Python only if fetched_all_process = False
// Create new TaskCpuInfo as required and update OpInfo
RETURN_IF_NOT_OK(UpdateTaskList());
// Sample SystemInfo - Update current and move current to previous stat and calc Util
SystemStat current_sys_stat;
SystemStat previous_sys_stat;
RETURN_IF_NOT_OK(sys_cpu_info_.SampleAndGetCurrPrevStat(&current_sys_stat, &previous_sys_stat));
auto total_time_elapsed = current_sys_stat.total_stat - previous_sys_stat.total_stat;
// Call Sample on all
// Read /proc/ files and get stat, calculate util
for (auto &task_ptr : tasks_) {
(void)task_ptr->Sample(total_time_elapsed);
}
// Calculate OperatorCpuInfo
for (auto &[op_id, op_info] : op_info_by_id_) {
MS_LOG(DEBUG) << "Calculate operator cpu utilization for OpId: " << op_id;
op_info.CalculateOperatorUtilization();
}
// Get sampling time.
(void)ts_.emplace_back(ProfilingTime::GetCurMilliSecond());
return Status::OK();
}
Status CpuSampler::UpdateTaskList() {
List<Task> allTasks = tree->AllTasks()->GetTask();
for (auto &task : allTasks) {
int32_t op_id = task.get_operator_id();
// check if the op_info was initialized in Init
auto iter = op_info_by_id_.find(op_id);
if (iter != op_info_by_id_.end()) {
int32_t tid = task.get_linux_id();
if (!iter->second.TaskExists(tid)) {
auto task_cpu_info_ptr = std::make_shared<ThreadCpuInfo>(main_pid_, tid);
(void)tasks_.emplace_back(task_cpu_info_ptr);
iter->second.AddTask(task_cpu_info_ptr);
}
}
}
if (!fetched_all_python_multiprocesses_) {
py::gil_scoped_acquire gil_acquire;
py::module ds = py::module::import("mindspore.dataset.engine.datasets");
py::tuple process_info = ds.attr("_get_operator_process")();
auto sub_process = py::reinterpret_borrow<py::dict>(process_info[0]);
fetched_all_python_multiprocesses_ = py::reinterpret_borrow<py::bool_>(process_info[1]);
// parse dict value
auto op_to_process = toIntMap(sub_process);
for (auto const &[op_id, process_list] : op_to_process) {
for (auto pid : process_list) {
auto iter = op_info_by_id_.find(op_id);
if (iter != op_info_by_id_.end()) {
if (!iter->second.TaskExists(pid)) {
auto task_cpu_info_ptr = std::make_shared<ProcessCpuInfo>(pid);
(void)tasks_.emplace_back(task_cpu_info_ptr);
iter->second.AddTask(task_cpu_info_ptr);
}
}
}
}
}
return Status::OK();
}
Status CpuSampler::Init(const std::string &dir_path, const std::string &device_id) {
#if defined(USING_LINUX)
main_pid_ = syscall(SYS_getpid);
#endif
auto path = Path(dir_path) / Path("minddata_cpu_utilization_" + device_id + ".json");
// remove file if it already exists
RETURN_IF_NOT_OK(path.Remove());
file_path_ = path.ToString();
for (auto iter = tree->begin(); iter != tree->end(); iter++) {
auto op_id = iter->id();
(void)op_info_by_id_.emplace(std::make_pair(op_id, MDOperatorCpuInfo(op_id)));
}
// The thread id of the main thread is the same as the process id.
main_thread_cpu_info_ = std::make_shared<ThreadCpuInfo>(main_pid_, main_pid_);
(void)tasks_.emplace_back(main_thread_cpu_info_);
main_process_cpu_info_ = std::make_shared<ProcessCpuInfo>(main_pid_);
(void)tasks_.emplace_back(main_process_cpu_info_);
return Status::OK();
}
Status CpuSampler::ChangeFileMode() {
if (chmod(common::SafeCStr(file_path_), S_IRUSR | S_IWUSR) == -1) {
std::string err_str = "Change file mode failed," + file_path_;
return Status(StatusCode::kMDUnexpectedError, err_str);
}
return Status::OK();
}
Status CpuSampler::SaveToFile() {
// construct json obj to write to file
json output;
output["cpu_processor_num"] = SystemCpuInfo::num_cpu_;
std::vector<uint8_t> system_user_util, system_sys_util;
// end_index = ts_.size() essentially means to get all sampled points
(void)sys_cpu_info_.GetUserCpuUtil(0, ts_.size(), &system_user_util);
(void)sys_cpu_info_.GetSysCpuUtil(0, ts_.size(), &system_sys_util);
output["device_info"] = {{"context_switch_count", sys_cpu_info_.GetContextSwitchCount()},
{"idle_utilization", sys_cpu_info_.GetIdleCpuUtil()},
{"io_utilization", sys_cpu_info_.GetIOCpuUtil()},
{"sys_utilization", system_sys_util},
{"user_utilization", system_user_util},
{"runnable_process", sys_cpu_info_.GetRunningProcess()}};
// array of op_info json objects
json op_infos;
for (auto &[op_id, op_info] : op_info_by_id_) {
MS_LOG(INFO) << "Processing op_id: " << op_id;
std::vector<uint16_t> user_util, sys_util;
(void)op_info.GetSysCpuUtil(0, ts_.size(), &sys_util);
(void)op_info.GetUserCpuUtil(0, ts_.size(), &user_util);
json op_info_json = {{"metrics", {{"user_utilization", user_util}, {"sys_utilization", sys_util}}},
{"op_id", op_id}};
op_infos.emplace_back(op_info_json);
}
output["op_info"] = op_infos;
output["process_info"] = {{"user_utilization", main_process_cpu_info_->GetUserCpuUtil()},
{"sys_utilization", main_process_cpu_info_->GetSysCpuUtil()}};
output["sampling_interval"] = GlobalContext::config_manager()->monitor_sampling_interval();
output["time_stamp"] = ts_;
// Discard the content of the file when opening.
std::ofstream os(file_path_, std::ios::trunc);
os << output;
os.close();
return Status::OK();
}
Status CpuSampler::Analyze() { return Status::OK(); }
Status CpuSampler::GetOpUserCpuUtil(int32_t op_id, uint64_t start_ts, uint64_t end_ts, std::vector<uint16_t> *result) {
std::lock_guard<std::mutex> guard(lock_);
// find first ts that is not less than start_ts
auto lower = std::lower_bound(ts_.begin(), ts_.end(), start_ts);
// find first ts that is greater than end_ts
auto upper = std::upper_bound(ts_.begin(), ts_.end(), end_ts);
// std::distance is O(1) since vector allows random access
auto start_index = std::distance(ts_.begin(), lower);
auto end_index = std::distance(ts_.begin(), upper);
auto op_info = op_info_by_id_.find(op_id);
CHECK_FAIL_RETURN_UNEXPECTED(op_info != op_info_by_id_.end(), "Op Id: " + std::to_string(op_id) + " not found.");
return op_info->second.GetUserCpuUtil(start_index, end_index, result);
}
Status CpuSampler::GetOpSysCpuUtil(int32_t op_id, uint64_t start_ts, uint64_t end_ts, std::vector<uint16_t> *result) {
std::lock_guard<std::mutex> guard(lock_);
// find first ts that is not less than start_ts
auto lower = std::lower_bound(ts_.begin(), ts_.end(), start_ts);
// find first ts that is greater than end_ts
auto upper = std::upper_bound(ts_.begin(), ts_.end(), end_ts);
// std::distance is O(1) since vector allows random access
auto start_index = std::distance(ts_.begin(), lower);
auto end_index = std::distance(ts_.begin(), upper);
auto op_info = op_info_by_id_.find(op_id);
CHECK_FAIL_RETURN_UNEXPECTED(op_info != op_info_by_id_.end(), "Op Id: " + std::to_string(op_id) + " not found.");
return op_info->second.GetSysCpuUtil(start_index, end_index, result);
}
Status CpuSampler::GetSystemUserCpuUtil(uint64_t start_ts, uint64_t end_ts, std::vector<uint8_t> *result) {
std::lock_guard<std::mutex> guard(lock_);
// find first ts that is not less than start_ts
auto lower = std::lower_bound(ts_.begin(), ts_.end(), start_ts);
// find first ts that is greater than end_ts
auto upper = std::upper_bound(ts_.begin(), ts_.end(), end_ts);
// std::distance is O(1) since vector allows random access
auto start_index = std::distance(ts_.begin(), lower);
auto end_index = std::distance(ts_.begin(), upper);
return sys_cpu_info_.GetUserCpuUtil(start_index, end_index, result);
}
Status CpuSampler::GetSystemSysCpuUtil(uint64_t start_ts, uint64_t end_ts, std::vector<uint8_t> *result) {
std::lock_guard<std::mutex> guard(lock_);
// find first ts that is not less than start_ts
auto lower = std::lower_bound(ts_.begin(), ts_.end(), start_ts);
// find first ts that is greater than end_ts
auto upper = std::upper_bound(ts_.begin(), ts_.end(), end_ts);
// std::distance is O(1) since vector allows random access
auto start_index = std::distance(ts_.begin(), lower);
auto end_index = std::distance(ts_.begin(), upper);
return sys_cpu_info_.GetSysCpuUtil(start_index, end_index, result);
}
} // namespace dataset
} // namespace mindspore
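
For reference, an abridged sketch of the JSON document SaveToFile writes (key names taken from the code above; values and array contents are illustrative only):

{
  "cpu_processor_num": 8,
  "device_info": {"context_switch_count": [...], "idle_utilization": [...], "io_utilization": [...],
                  "sys_utilization": [...], "user_utilization": [...], "runnable_process": [...]},
  "op_info": [{"op_id": 0, "metrics": {"user_utilization": [...], "sys_utilization": [...]}}],
  "process_info": {"user_utilization": [...], "sys_utilization": [...]},
  "sampling_interval": 10,
  "time_stamp": [...]
}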

View File

@ -0,0 +1,169 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PERF_CPU_SAMPLER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PERF_CPU_SAMPLER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <nlohmann/json.hpp>
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
namespace mindspore {
namespace dataset {
class ExecutionTree;
typedef struct SystemStat_s {
uint64_t user_stat;
uint64_t sys_stat;
uint64_t io_stat;
uint64_t idle_stat;
uint64_t total_stat;
} SystemStat;
typedef struct SystemUtil_s {
uint8_t user_utilization;
uint8_t sys_utilization;
uint8_t io_utilization;
uint8_t idle_utilization;
} SystemUtil;
typedef struct TaskStat_s {
uint64_t user_stat;
uint64_t sys_stat;
} TaskStat;
struct TaskUtil_s {
float user_utilization;
float sys_utilization;
};
typedef struct TaskUtil_s TaskUtil;
typedef struct TaskUtil_s OpUtil;
class SystemCpuInfo {
public:
SystemCpuInfo() : first_sample_(true), prev_context_switch_count_(0) {}
// Read in current stats and return previous and currently read stats
Status SampleAndGetCurrPrevStat(SystemStat *current_stat, SystemStat *previous_stat);
static int32_t num_cpu_;
const std::vector<uint32_t> &GetRunningProcess() const { return running_process_; }
const std::vector<uint64_t> &GetContextSwitchCount() const { return context_switch_count_; }
Status GetUserCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint8_t> *result) const;
Status GetSysCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint8_t> *result) const;
std::vector<uint8_t> GetIOCpuUtil() const;
std::vector<uint8_t> GetIdleCpuUtil() const;
private:
Status ParseCpuInfo(const std::string &str);
Status ParseCtxt(const std::string &str);
Status ParseRunningProcess(const std::string &str);
SystemStat prev_sys_stat_{}; // last read data /proc/stat file
std::vector<SystemUtil> sys_cpu_util_; // vector of system cpu utilization
std::vector<uint32_t> running_process_; // vector of running processes in system
std::vector<uint64_t> context_switch_count_; // vector of number of context switches between two sampling points
bool first_sample_; // flag to indicate first time sampling
uint64_t prev_context_switch_count_; // last read context switch count from /proc/stat file
};
class TaskCpuInfo {
public:
explicit TaskCpuInfo(pid_t pid) : pid_(pid), first_sample_(true), last_sampling_failed_(false) {}
virtual ~TaskCpuInfo() = default;
virtual Status Sample(uint64_t total_time_elapsed) = 0;
virtual pid_t GetId() = 0;
TaskUtil GetLatestCpuUtil() const;
std::vector<uint16_t> GetSysCpuUtil() const;
std::vector<uint16_t> GetUserCpuUtil() const;
protected:
pid_t pid_;
TaskStat prev_task_stat_;
std::vector<TaskUtil> task_cpu_util_;
bool first_sample_;
bool last_sampling_failed_;
};
class ProcessCpuInfo : public TaskCpuInfo {
public:
explicit ProcessCpuInfo(pid_t pid) : TaskCpuInfo(pid) {}
~ProcessCpuInfo() override = default;
Status Sample(uint64_t total_time_elapsed) override;
pid_t GetId() override { return pid_; }
};
class ThreadCpuInfo : public TaskCpuInfo {
public:
explicit ThreadCpuInfo(pid_t pid, pid_t tid) : TaskCpuInfo(pid), tid_(tid) {}
~ThreadCpuInfo() override = default;
Status Sample(uint64_t total_time_elapsed) override;
pid_t GetId() override { return tid_; }
private:
pid_t tid_;
};
class MDOperatorCpuInfo {
public:
void AddTask(const std::shared_ptr<TaskCpuInfo> &task_ptr);
bool TaskExists(pid_t id) const;
explicit MDOperatorCpuInfo(const int32_t op_id) : id_(op_id) {}
void CalculateOperatorUtilization();
Status GetUserCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint16_t> *result) const;
Status GetSysCpuUtil(uint64_t start_index, uint64_t end_index, std::vector<uint16_t> *result) const;
private:
int32_t id_;
// tid is the key for ThreadCpuInfo, pid is the key for ProcessCpuInfo
std::unordered_map<pid_t, std::shared_ptr<TaskCpuInfo>> task_by_id_;
std::vector<OpUtil> op_cpu_util_;
};
class CpuSampler : public Sampling {
using Timestamps = std::vector<uint64_t>;
public:
explicit CpuSampler(ExecutionTree *tree) : fetched_all_python_multiprocesses_(false), tree(tree) {}
~CpuSampler() = default;
Status Sample() override;
Status Init(const std::string &dir_path, const std::string &device_id) override;
Status ChangeFileMode() override;
Status SaveToFile() override;
std::string Name() const override { return kCpuSamplerName; }
Status Analyze() override;
Status GetSystemUserCpuUtil(uint64_t start_ts, uint64_t end_ts, std::vector<uint8_t> *result);
Status GetSystemSysCpuUtil(uint64_t start_ts, uint64_t end_ts, std::vector<uint8_t> *result);
Status GetOpUserCpuUtil(int32_t op_id, uint64_t start_ts, uint64_t end_ts, std::vector<uint16_t> *result);
Status GetOpSysCpuUtil(int32_t op_id, uint64_t start_ts, uint64_t end_ts, std::vector<uint16_t> *result);
private:
Status UpdateTaskList();
bool fetched_all_python_multiprocesses_{};
ExecutionTree *tree = nullptr;
pid_t main_pid_{};
Timestamps ts_;
SystemCpuInfo sys_cpu_info_; // store the system cpu utilization
std::vector<std::shared_ptr<TaskCpuInfo>> tasks_; // vector of all process and thread tasks
std::shared_ptr<ThreadCpuInfo> main_thread_cpu_info_;
std::shared_ptr<ProcessCpuInfo> main_process_cpu_info_;
std::unordered_map<int32_t, MDOperatorCpuInfo> op_info_by_id_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_PERF_CPU_SAMPLER_H_
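
A hedged end-to-end sketch of the sampler lifecycle this header suggests (driver simplified, error handling elided; start_ts and end_ts are millisecond timestamps as recorded by Sample()):

auto cpu_sampler = std::make_shared<CpuSampler>(tree);
RETURN_IF_NOT_OK(cpu_sampler->Init(dir_path, device_id));
// The profiling monitor invokes Sample() once per monitor_sampling_interval.
RETURN_IF_NOT_OK(cpu_sampler->Sample());
// Autotune-style consumers query a time window:
std::vector<uint16_t> op_user_util;
RETURN_IF_NOT_OK(cpu_sampler->GetOpUserCpuUtil(op_id, start_ts, end_ts, &op_user_util));
// On shutdown, persist everything and restrict file permissions.
RETURN_IF_NOT_OK(cpu_sampler->SaveToFile());
RETURN_IF_NOT_OK(cpu_sampler->ChangeFileMode());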

View File

@ -1,699 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/perf/cpu_sampling.h"
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <sys/syscall.h>
#endif
#include <cmath>
#include <cstdio>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/path.h"
using json = nlohmann::json;
namespace mindspore {
namespace dataset {
bool BaseCpu::fetched_all_process_shared_ = false;
std::unordered_map<int32_t, std::vector<pid_t>> BaseCpu::op_process_shared_ = {};
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#define USING_LINUX
#endif
BaseCpu::BaseCpu() {
pre_cpu_stat_.user_stat_ = 0;
pre_cpu_stat_.sys_stat_ = 0;
pre_cpu_stat_.io_stat_ = 0;
pre_cpu_stat_.idle_stat_ = 0;
pre_cpu_stat_.total_stat_ = 0;
fetched_all_process_ = false;
pre_fetched_state_ = false;
cpu_processor_num_ = 0;
}
Status BaseCpu::GetTotalCpuTime(uint64_t *total_stat) {
std::ifstream file("/proc/stat");
if (!file.is_open()) {
MS_LOG(INFO) << "Open CPU file failed when collect CPU information";
return Status::OK();
}
std::string str;
getline(file, str);
uint64_t user = 0, sys = 0, idle = 0, iowait = 0, nice = 0, irq = 0, softirq = 0;
if (sscanf_s(str.c_str(), "%*s %lu %lu %lu %lu %lu %lu %lu", &user, &nice, &sys, &idle, &iowait, &irq, &softirq) ==
EOF) {
file.close();
return Status(StatusCode::kMDUnexpectedError, "Get device CPU failed.");
}
file.close();
*total_stat = user + nice + sys + idle + iowait + irq + softirq;
return Status::OK();
}
Status DeviceCpu::ParseCpuInfo(const std::string &str) {
CpuStat cpu_stat;
uint64_t nice = 0;
uint64_t irq = 0;
uint64_t softirq = 0;
if (sscanf_s(str.c_str(), "%*s %lu %lu %lu %lu %lu %lu %lu", &cpu_stat.user_stat_, &nice, &cpu_stat.sys_stat_,
&cpu_stat.idle_stat_, &cpu_stat.io_stat_, &irq, &softirq) == EOF) {
return Status(StatusCode::kMDUnexpectedError, "Get device CPU failed.");
}
cpu_stat.total_stat_ =
cpu_stat.user_stat_ + nice + cpu_stat.sys_stat_ + cpu_stat.idle_stat_ + cpu_stat.io_stat_ + irq + softirq;
// Calculate the utilization from the second sampling
if (!first_collect_) {
CpuUtil info;
info.user_utilization_ = round((cpu_stat.user_stat_ - pre_cpu_stat_.user_stat_) * 1.0 /
(cpu_stat.total_stat_ - pre_cpu_stat_.total_stat_) * 100);
info.sys_utilization_ = round((cpu_stat.sys_stat_ - pre_cpu_stat_.sys_stat_) * 1.0 /
(cpu_stat.total_stat_ - pre_cpu_stat_.total_stat_) * 100);
info.io_utilization_ = round((cpu_stat.io_stat_ - pre_cpu_stat_.io_stat_) * 1.0 /
(cpu_stat.total_stat_ - pre_cpu_stat_.total_stat_) * 100);
info.idle_utilization_ = round((cpu_stat.idle_stat_ - pre_cpu_stat_.idle_stat_) * 1.0 /
(cpu_stat.total_stat_ - pre_cpu_stat_.total_stat_) * 100);
cpu_util_.emplace_back(info);
}
pre_cpu_stat_.user_stat_ = cpu_stat.user_stat_;
pre_cpu_stat_.sys_stat_ = cpu_stat.sys_stat_;
pre_cpu_stat_.io_stat_ = cpu_stat.io_stat_;
pre_cpu_stat_.idle_stat_ = cpu_stat.idle_stat_;
pre_cpu_stat_.total_stat_ = cpu_stat.total_stat_;
return Status::OK();
}
Status DeviceCpu::ParseCtxt(const std::string &str) {
uint64_t ctxt;
if (sscanf_s(str.c_str(), "%*s %lu", &ctxt) == EOF) {
return Status(StatusCode::kMDUnexpectedError, "Get context switch count failed.");
}
// Calculate the utilization from the second sampling
if (!first_collect_) {
context_switch_count_.push_back(ctxt - pre_context_switch_count_);
}
pre_context_switch_count_ = ctxt;
return Status::OK();
}
Status DeviceCpu::ParseRunningProcess(const std::string &str) {
uint32_t running_process;
if (sscanf_s(str.c_str(), "%*s %ud", &running_process) == EOF) {
return Status(StatusCode::kMDUnexpectedError, "Get context switch count failed.");
}
// Drop the first value in order to collect same amount of CPU utilization
if (!first_collect_) {
running_process_.push_back(running_process);
}
return Status::OK();
}
Status DeviceCpu::Collect(const ExecutionTree *tree) {
std::ifstream file("/proc/stat");
if (!file.is_open()) {
MS_LOG(INFO) << "Open CPU file failed when collect CPU information";
return Status::OK();
}
bool first_line = true;
std::string line;
while (getline(file, line)) {
if (first_line) {
first_line = false;
RETURN_IF_NOT_OK(ParseCpuInfo(line));
}
if (line.find("ctxt") != std::string::npos) {
RETURN_IF_NOT_OK(ParseCtxt(line));
}
if (line.find("procs_running") != std::string::npos) {
RETURN_IF_NOT_OK(ParseRunningProcess(line));
}
}
file.close();
first_collect_ = false;
return Status::OK();
}
Status DeviceCpu::Analyze(std::string *name, double *utilization, std::string *extra_message) {
RETURN_UNEXPECTED_IF_NULL(name);
RETURN_UNEXPECTED_IF_NULL(utilization);
name->clear();
name->append("device_info");
int total_samples = cpu_util_.size();
int sum = 0;
// Only analyze the middle half of the samples
// Starting and ending may be impacted by startup or ending pipeline activities
int start_analyze = total_samples / 4;
int end_analyze = total_samples - start_analyze;
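// e.g. with 40 samples, start_analyze = 10 and end_analyze = 30, so only
// samples [10, 30) contribute to the average.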
for (int i = start_analyze; i < end_analyze; i++) {
sum += cpu_util_[i].user_utilization_;
sum += cpu_util_[i].sys_utilization_;
}
// Note: device utilization is already normalized over all cores (0-100%),
// so there is no need to divide by the number of CPUs
if ((end_analyze - start_analyze) > 0) {
*utilization = 1.0 * sum / (end_analyze - start_analyze);
}
return Status::OK();
}
Status DeviceCpu::SaveToFile(const std::string &file_path) {
Path path = Path(file_path);
json output;
if (path.Exists()) {
MS_LOG(DEBUG) << file_path << " exists already";
try {
std::ifstream file(file_path);
file >> output;
} catch (const std::exception &err) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open json file: " + file_path +
", please delete it and try again!");
}
} else {
output["sampling_interval"] = GlobalContext::config_manager()->monitor_sampling_interval();
}
std::vector<int8_t> user_util;
std::transform(cpu_util_.begin(), cpu_util_.end(), std::back_inserter(user_util),
[&](const CpuUtil &info) { return info.user_utilization_; });
std::vector<int8_t> sys_util;
std::transform(cpu_util_.begin(), cpu_util_.end(), std::back_inserter(sys_util),
[&](const CpuUtil &info) { return info.sys_utilization_; });
std::vector<int8_t> io_util;
std::transform(cpu_util_.begin(), cpu_util_.end(), std::back_inserter(io_util),
[&](const CpuUtil &info) { return info.io_utilization_; });
std::vector<int8_t> idle_util;
std::transform(cpu_util_.begin(), cpu_util_.end(), std::back_inserter(idle_util),
[&](const CpuUtil &info) { return info.idle_utilization_; });
output["device_info"] = {{"user_utilization", user_util},
{"sys_utilization", sys_util},
{"io_utilization", io_util},
{"idle_utilization", idle_util},
{"runable_processes", running_process_},
{"context_switch_count", context_switch_count_}};
// Discard the content of the file when opening.
std::ofstream os(file_path, std::ios::trunc);
os << output;
os.close();
MS_LOG(INFO) << "Save device CPU success.";
return Status::OK();
}
Status OperatorCpu::ParseCpuInfo(int32_t op_id, int64_t thread_id,
std::unordered_map<int32_t, std::unordered_map<int64_t, CpuOpStat>> *op_stat) {
RETURN_UNEXPECTED_IF_NULL(op_stat);
pid_t pid = 0;
#if defined(USING_LINUX)
pid = syscall(SYS_getpid);
#endif
std::string stat_path = "/proc/" + std::to_string(pid) + "/task/" + std::to_string(thread_id) + "/stat";
// Check whether the file exists first
Path temp_path(stat_path);
if (!temp_path.Exists()) {
(*op_stat)[op_id][thread_id].user_stat_ = 0;
(*op_stat)[op_id][thread_id].sys_stat_ = 0;
return Status(StatusCode::kMDFileNotExist);
}
std::ifstream file(stat_path);
if (!file.is_open()) {
MS_LOG(INFO) << "Open CPU file failed when collect CPU information";
return Status::OK();
}
std::string str;
getline(file, str);
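// Fields 14 and 15 of /proc/<pid>/task/<tid>/stat are utime and stime (CPU time
// spent in user and kernel mode, in clock ticks); the 13 leading fields are
// skipped with assignment-suppressing conversions.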
uint64_t utime = 0;
uint64_t stime = 0;
if (sscanf_s(str.c_str(), "%*d %*s %*s %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %lu %lu", &utime, &stime) !=
2) {
file.close();
return Status(StatusCode::kMDUnexpectedError, "Failed to parse operator CPU information from " + stat_path + ".");
}
file.close();
(*op_stat)[op_id][thread_id].user_stat_ = utime;
(*op_stat)[op_id][thread_id].sys_stat_ = stime;
return Status::OK();
}
Status OperatorCpu::Collect(const ExecutionTree *tree) {
RETURN_UNEXPECTED_IF_NULL(tree);
if (first_collect_) {
for (auto iter = tree->begin(); iter != tree->end(); ++iter) {
id_count_++;
op_name_[iter->id()] = iter->NameWithID();
op_parallel_workers_[iter->id()] = iter->NumWorkers();
}
#if defined(USING_LINUX)
cpu_processor_num_ = get_nprocs_conf();
#endif
}
// Obtain the op and thread mapping
op_thread_.clear();
List<Task> all_tasks = tree->AllTasks()->GetTask();
for (auto &task : all_tasks) {
int32_t op_id = task.get_operator_id();
op_thread_[op_id].emplace_back(task.get_linux_id());
}
// add process id into op_thread
if (!fetched_all_process_) {
{
py::gil_scoped_acquire gil_acquire;
py::module ds = py::module::import("mindspore.dataset.engine.datasets");
py::tuple process_info = ds.attr("_get_operator_process")();
py::dict sub_process = py::reinterpret_borrow<py::dict>(process_info[0]);
fetched_all_process_ = py::reinterpret_borrow<py::bool_>(process_info[1]);
// parse dict value
op_process_ = toIntMap(sub_process);
BaseCpu::op_process_shared_ = op_process_;
BaseCpu::fetched_all_process_shared_ = fetched_all_process_;
}
// If there is a device_queue operator, operator ids may need to be shifted by one; use the ids directly for now
for (const auto &item : op_process_) {
if (!item.second.empty()) {
if (op_thread_.find(item.first) != op_thread_.end()) {
op_thread_[item.first].insert(op_thread_[item.first].end(), item.second.begin(), item.second.end());
} else {
op_thread_[item.first] = item.second;
}
}
}
}
uint64_t total_stat = 0;
RETURN_IF_NOT_OK(GetTotalCpuTime(&total_stat));
std::vector<CpuOpUtil> cpu_step_util;
std::unordered_map<int32_t, std::unordered_map<int64_t, CpuOpStat>> op_stat;
if (!first_collect_) {
// Obtain all the op ids in the current tasks
std::vector<int32_t> total_op_id;
(void)std::transform(op_thread_.begin(), op_thread_.end(), std::back_inserter(total_op_id),
[](const auto &iter) { return iter.first; });
// Iterate over all the ops and obtain the CPU utilization of each operator
for (auto op_id = -1; op_id < id_count_; op_id++) {
float user_util = 0, sys_util = 0;
auto iter = std::find(total_op_id.begin(), total_op_id.end(), op_id);
if (iter != total_op_id.end()) {
for (auto thread_id : op_thread_[op_id]) {
if (ParseCpuInfo(op_id, thread_id, &op_stat).IsOk()) {
user_util += (op_stat[op_id][thread_id].user_stat_ - pre_op_stat_[op_id][thread_id].user_stat_) * 1.0 /
(total_stat - pre_total_stat_) * 100;
sys_util += (op_stat[op_id][thread_id].sys_stat_ - pre_op_stat_[op_id][thread_id].sys_stat_) * 1.0 /
(total_stat - pre_total_stat_) * 100;
}
}
}
CpuOpUtil info;
info.op_id_ = op_id;
info.sys_utilization_ = sys_util;
info.user_utilization_ = user_util;
cpu_step_util.emplace_back(info);
}
cpu_op_util_.emplace_back(cpu_step_util);
} else {
// The first collection mainly records the initial CPU execution times
for (const auto &iter : op_thread_) {
int32_t op_id = iter.first;
for (auto thread_id : iter.second) {
// ParseCpuInfo may fail while the CPU data is not ready yet; keep going and try the next thread
(void)ParseCpuInfo(op_id, thread_id, &op_stat);
}
}
}
// copy current op_stat into pre_op_stat
pre_op_stat_ = op_stat;
pre_total_stat_ = total_stat;
first_collect_ = false;
return Status::OK();
}
Status OperatorCpu::Analyze(std::string *name, double *utilization, std::string *extra_message) {
RETURN_UNEXPECTED_IF_NULL(name);
RETURN_UNEXPECTED_IF_NULL(utilization);
RETURN_UNEXPECTED_IF_NULL(extra_message);
int total_samples = cpu_op_util_.size();
// Only analyze the middle half of the samples
// Starting and ending may be impacted by startup or ending pipeline activities
constexpr int64_t sample_sections = 4;
int64_t start_analyze = total_samples / sample_sections;
int64_t end_analyze = total_samples - start_analyze;
double op_util = 0;
*utilization = 0;
// Start the loop from 0, since op -1 (the monitor thread) should not be analyzed
for (auto op_id = 0; op_id < id_count_; op_id++) {
int64_t sum = 0;
int64_t index = op_id + 1;
for (int i = start_analyze; i < end_analyze; i++) {
sum += cpu_op_util_[i][index].user_utilization_;
sum += cpu_op_util_[i][index].sys_utilization_;
}
if ((end_analyze - start_analyze) > 0) {
op_util = 1.0 * sum * cpu_processor_num_ / (op_parallel_workers_[op_id] * (end_analyze - start_analyze));
}
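// e.g. with 8 logical cores, 4 parallel workers, and a summed utilization of
// 200 over 20 analyzed samples: op_util = 200 * 8 / (4 * 20) = 20% per thread.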
if (op_util > *utilization) {
*utilization = op_util;
name->clear();
(void)name->append(op_name_[op_id]);
}
(void)extra_message->append(op_name_[op_id] + " utilization per thread: " + std::to_string(op_util) + "% (" +
std::to_string(op_parallel_workers_[op_id]) + " parallel_workers); ");
}
return Status::OK();
}
Status OperatorCpu::SaveToFile(const std::string &file_path) {
Path path = Path(file_path);
json output;
if (path.Exists()) {
MS_LOG(DEBUG) << file_path << "already exist.";
try {
std::ifstream file(file_path);
file >> output;
} catch (const std::exception &err) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open json file: " + file_path +
", please delete it and try again!");
}
}
size_t index = 0;
json OpWriter;
for (auto op_id = -1; op_id < id_count_; op_id++) {
std::vector<uint16_t> user_util;
std::vector<uint16_t> sys_util;
std::transform(
cpu_op_util_.begin(), cpu_op_util_.end(), std::back_inserter(user_util),
[&](const std::vector<CpuOpUtil> &info) { return uint16_t(info[index].user_utilization_ * cpu_processor_num_); });
std::transform(
cpu_op_util_.begin(), cpu_op_util_.end(), std::back_inserter(sys_util),
[&](const std::vector<CpuOpUtil> &info) { return uint16_t(info[index].sys_utilization_ * cpu_processor_num_); });
json per_op_info = {{"metrics", {{"user_utilization", user_util}, {"sys_utilization", sys_util}}},
{"op_id", op_id}};
OpWriter.emplace_back(per_op_info);
index++;
}
output["op_info"] = OpWriter;
// Discard the content of the file when opening.
std::ofstream os(file_path, std::ios::trunc);
os << output;
os.close();
MS_LOG(INFO) << "Save device CPU success.";
return Status::OK();
}
Status ProcessCpu::ParseCpuInfo() {
uint64_t total_stat = 0;
RETURN_IF_NOT_OK(GetTotalCpuTime(&total_stat));
if (!pre_fetched_state_) {
process_id_.clear();
pid_t main_pid = 0;
#if defined(USING_LINUX)
main_pid = syscall(SYS_getpid);
#endif
process_id_.emplace_back(main_pid);
op_process_ = BaseCpu::op_process_shared_;
fetched_all_process_ = BaseCpu::fetched_all_process_shared_;
for (const auto &item : op_process_) {
for (const auto &id : item.second) {
process_id_.emplace_back(id);
}
}
}
float user_util = 0, sys_util = 0;
for (const auto &pid : process_id_) {
std::string stat_path = "/proc/" + std::to_string(pid) + "/stat";
std::ifstream file(stat_path);
if (!file.is_open()) {
MS_LOG(INFO) << "Open CPU file failed when collect CPU information";
continue;
}
std::string str;
getline(file, str);
uint64_t user = 0, sys = 0;
if (sscanf_s(str.c_str(), "%*d %*s %*s %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %*lu %lu %lu", &user, &sys) !=
2) {
file.close();
return Status(StatusCode::kMDUnexpectedError, "Failed to parse process CPU information from " + stat_path + ".");
}
file.close();
// Calculate the utilization from the second sampling
if (!first_collect_ && (pre_process_stat_.find(pid) != pre_process_stat_.end())) {
user_util += (user - pre_process_stat_[pid].user_stat_) * 1.0 / (total_stat - pre_total_stat_) * 100;
sys_util += (sys - pre_process_stat_[pid].sys_stat_) * 1.0 / (total_stat - pre_total_stat_) * 100;
}
pre_process_stat_[pid].user_stat_ = user;
pre_process_stat_[pid].sys_stat_ = sys;
}
if (!first_collect_) {
CpuProcessUtil info;
info.user_utilization_ = user_util;
info.sys_utilization_ = sys_util;
process_util_.emplace_back(info);
}
pre_total_stat_ = total_stat;
first_collect_ = false;
pre_fetched_state_ = fetched_all_process_;
return Status::OK();
}
Status ProcessCpu::Collect(const ExecutionTree *tree) {
RETURN_UNEXPECTED_IF_NULL(tree);
if (first_collect_) {
#if defined(USING_LINUX)
cpu_processor_num_ = get_nprocs_conf();
#endif
}
RETURN_IF_NOT_OK(ParseCpuInfo());
return Status::OK();
}
Status ProcessCpu::Analyze(std::string *name, double *utilization, std::string *extra_message) {
RETURN_UNEXPECTED_IF_NULL(name);
RETURN_UNEXPECTED_IF_NULL(utilization);
RETURN_UNEXPECTED_IF_NULL(extra_message);
name->clear();
name->append("process_info");
int total_samples = process_util_.size();
int64_t sum = 0;
// Only analyze the middle half of the samples
// Starting and ending may be impacted by startup or ending pipeline activities
constexpr int64_t sample_sections = 4;
int64_t start_analyze = total_samples / sample_sections;
int64_t end_analyze = total_samples - start_analyze;
for (int i = start_analyze; i < end_analyze; i++) {
sum += process_util_[i].user_utilization_;
sum += process_util_[i].sys_utilization_;
}
if ((end_analyze - start_analyze) > 0) {
*utilization = 1.0 * sum / (end_analyze - start_analyze);
}
return Status::OK();
}
Status ProcessCpu::SaveToFile(const std::string &file_path) {
Path path = Path(file_path);
json output;
if (path.Exists()) {
MS_LOG(DEBUG) << file_path << " already exists.";
try {
std::ifstream file(file_path);
file >> output;
} catch (const std::exception &err) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open json file: " + file_path +
", please delete it and try again!");
}
} else {
output["sampling_interval"] = GlobalContext::config_manager()->monitor_sampling_interval();
}
std::vector<uint16_t> user_util;
std::transform(process_util_.begin(), process_util_.end(), std::back_inserter(user_util),
[&](const CpuProcessUtil &info) { return uint16_t(info.user_utilization_ * cpu_processor_num_); });
std::vector<uint16_t> sys_util;
std::transform(process_util_.begin(), process_util_.end(), std::back_inserter(sys_util),
[&](const CpuProcessUtil &info) { return uint16_t(info.sys_utilization_ * cpu_processor_num_); });
output["process_info"] = {{"user_utilization", user_util}, {"sys_utilization", sys_util}};
output["cpu_processor_num"] = cpu_processor_num_;
// Discard the content of the file when opening.
std::ofstream os(file_path, std::ios::trunc);
os << output;
os.close();
MS_LOG(INFO) << "Save process CPU success.";
return Status::OK();
}
Status CpuSampling::CollectTimeStamp() {
time_stamp_.emplace_back(ProfilingTime::GetCurMilliSecond());
return Status::OK();
}
// Sample action
Status CpuSampling::Sample() {
// Collect cpu information
for (const auto &cpu : cpu_) {
RETURN_IF_NOT_OK(cpu->Collect(this->tree_));
}
// Collect time stamp
RETURN_IF_NOT_OK(CollectTimeStamp());
return Status::OK();
}
Status CpuSampling::SaveTimeStampToFile() {
// Save time stamp to json file
// If the file already exists, simply add the data to the corresponding field.
Path path = Path(file_path_);
json output;
if (path.Exists()) {
try {
std::ifstream file(file_path_);
file >> output;
} catch (const std::exception &err) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open json file: " + file_path_ +
", please delete it and try again!");
}
}
output["time_stamp"] = time_stamp_;
std::ofstream os(file_path_, std::ios::trunc);
os << output;
os.close();
return Status::OK();
}
Status CpuSampling::SaveSamplingIntervalToFile() {
// If the file already exists, simply add the data to the corresponding field.
Path path = Path(file_path_);
json output;
if (path.Exists()) {
try {
std::ifstream file(file_path_);
file >> output;
} catch (const std::exception &err) {
RETURN_STATUS_UNEXPECTED("Invalid file, failed to open json file: " + file_path_ +
", please delete it and try again!");
}
}
output["sampling_interval"] = GlobalContext::config_manager()->monitor_sampling_interval();
std::ofstream os(file_path_, std::ios::trunc);
os << output;
os.close();
return Status::OK();
}
// Analyze profiling data and output warning messages
Status CpuSampling::Analyze() {
std::string name;
double utilization = 0;
constexpr double total_cpu_thold = 90;
constexpr double op_cpu_thold = 80;
// Keep track of specific information returned by different CPU sampling types
double total_utilization = 0;
double max_op_utilization = 0;
std::string max_op_name;
std::string detailed_op_cpu_message;
// Save cpu information to json file
for (const auto &cpu : cpu_) {
std::string extra_message;
RETURN_IF_NOT_OK(cpu->Analyze(&name, &utilization, &extra_message));
if (name == "device_info") {
total_utilization = utilization;
} else if (name != "process_info") {
max_op_utilization = utilization;
max_op_name = name;
detailed_op_cpu_message = extra_message;
}
}
if ((total_utilization < total_cpu_thold) && (max_op_utilization > op_cpu_thold)) {
MS_LOG(WARNING) << "Operator " << max_op_name << " is using " << max_op_utilization << "% CPU per thread. "
<< "This operator may benefit from increasing num_parallel_workers."
<< "Full Operator CPU utiliization for all operators: " << detailed_op_cpu_message << std::endl;
}
return Status::OK();
}
// Save profiling data to file
Status CpuSampling::SaveToFile() {
// Save time stamp to json file
RETURN_IF_NOT_OK(SaveTimeStampToFile());
// Save sampling interval to json file
RETURN_IF_NOT_OK(SaveSamplingIntervalToFile());
// Save cpu information to json file
for (const auto &cpu : cpu_) {
RETURN_IF_NOT_OK(cpu->SaveToFile(file_path_));
}
return Status::OK();
}
Status CpuSampling::Init(const std::string &dir_path, const std::string &device_id) {
file_path_ = (Path(dir_path) / Path("minddata_cpu_utilization_" + device_id + ".json")).ToString();
std::shared_ptr<DeviceCpu> device_cpu = std::make_shared<DeviceCpu>();
std::shared_ptr<OperatorCpu> operator_cpu = std::make_shared<OperatorCpu>();
std::shared_ptr<ProcessCpu> process_cpu = std::make_shared<ProcessCpu>();
cpu_.push_back(device_cpu);
cpu_.push_back(operator_cpu);
cpu_.push_back(process_cpu);
return Status::OK();
}
Status CpuSampling::ChangeFileMode() {
if (chmod(common::SafeCStr(file_path_), S_IRUSR | S_IWUSR) == -1) {
std::string err_str = "Change file mode failed," + file_path_;
return Status(StatusCode::kMDUnexpectedError, err_str);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -1,210 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CPU_SAMPLING_H
#define MINDSPORE_CCSRC_MINDDATA_DATASET_CPU_SAMPLING_H
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <nlohmann/json.hpp>
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
namespace mindspore {
namespace dataset {
class ExecutionTree;
// CPU information from /proc/stat or /proc/pid/stat file
typedef struct CpuStat_s {
uint64_t user_stat_;
uint64_t sys_stat_;
uint64_t io_stat_;
uint64_t idle_stat_;
uint64_t total_stat_;
} CpuStat;
// Cpu utilization
typedef struct CpuInfo_s {
uint8_t user_utilization_;
uint8_t sys_utilization_;
uint8_t io_utilization_;
uint8_t idle_utilization_;
} CpuUtil;
// CPU utilization of operator
typedef struct CpuOpInfo_s {
float user_utilization_;
float sys_utilization_;
int32_t op_id_;
} CpuOpUtil;
// CPU utilization of process
typedef struct CpuProcessInfo_s {
float user_utilization_;
float sys_utilization_;
} CpuProcessUtil;
// CPU stat of operator
typedef struct CpuOpStat_s {
uint64_t user_stat_;
uint64_t sys_stat_;
} CpuOpStat;
class BaseCpu {
public:
BaseCpu();
~BaseCpu() = default;
// Collect CPU information
virtual Status Collect(const ExecutionTree *tree) = 0;
virtual Status SaveToFile(const std::string &file_path) = 0;
virtual Status Analyze(std::string *name, double *utilization, std::string *extra_message) = 0;
// Get the total CPU time of device
Status GetTotalCpuTime(uint64_t *total_stat);
protected:
std::vector<CpuUtil> cpu_util_;
CpuStat pre_cpu_stat_;
static bool fetched_all_process_shared_;
static std::unordered_map<int32_t, std::vector<pid_t>> op_process_shared_;
bool fetched_all_process_;
bool pre_fetched_state_;
std::unordered_map<int32_t, std::vector<pid_t>> op_process_;
int32_t cpu_processor_num_;
};
// Collect device CPU information
class DeviceCpu : public BaseCpu {
public:
DeviceCpu() : pre_running_process_(0), pre_context_switch_count_(0), first_collect_(true) {}
~DeviceCpu() = default;
Status Collect(const ExecutionTree *tree) override;
Status SaveToFile(const std::string &file_path) override;
Status Analyze(std::string *name, double *utilization, std::string *extra_message) override;
private:
// Get CPU information, include use/sys/idle/io utilization
Status ParseCpuInfo(const std::string &str);
// Get context switch count
Status ParseCtxt(const std::string &str);
// Get running process count
Status ParseRunningProcess(const std::string &str);
std::vector<uint32_t> running_process_;
std::vector<uint64_t> context_switch_count_;
uint32_t pre_running_process_;
uint64_t pre_context_switch_count_;
bool first_collect_;
};
// Collect operator CPU information
class OperatorCpu : public BaseCpu {
public:
OperatorCpu() : first_collect_(true), pre_total_stat_(0), id_count_(0) {}
~OperatorCpu() = default;
Status Collect(const ExecutionTree *tree) override;
Status SaveToFile(const std::string &file_path) override;
// Analyze will output the name of the metric, the average utilization of the
// busiest object within the class, and any extra message that would be useful
// for the user. The higher-level CpuSampling class combines information from
// the different classes to decide whether a warning should be output.
Status Analyze(std::string *name, double *utilization, std::string *extra_message) override;
private:
// Get cpu information, include use/sys/idle/io utilization
Status ParseCpuInfo(int32_t op_id, int64_t thread_id,
std::unordered_map<int32_t, std::unordered_map<int64_t, CpuOpStat>> *op_stat);
// Store the CPU utilization of each operator
std::vector<std::vector<CpuOpUtil>> cpu_op_util_;
bool first_collect_;
// Store the id and its corresponding threads.
std::unordered_map<int32_t, std::vector<pid_t>> op_thread_;
std::unordered_map<int32_t, std::string> op_name_;
std::unordered_map<int32_t, int32_t> op_parallel_workers_;
std::unordered_map<int32_t, std::unordered_map<int64_t, CpuOpStat>> pre_op_stat_;
uint64_t pre_total_stat_;
int32_t id_count_;
};
// Collect operator CPU information
class ProcessCpu : public BaseCpu {
public:
ProcessCpu() : first_collect_(true), pre_total_stat_(0) {}
~ProcessCpu() = default;
Status Collect(const ExecutionTree *tree) override;
Status SaveToFile(const std::string &file_path) override;
Status Analyze(std::string *name, double *utilization, std::string *extra_message) override;
private:
// Get CPU information, include use/sys/idle/io utilization
Status ParseCpuInfo();
bool first_collect_;
std::vector<CpuProcessUtil> process_util_;
uint64_t pre_total_stat_;
std::unordered_map<int64_t, CpuOpStat> pre_process_stat_;
std::vector<pid_t> process_id_;
};
// Sampling CPU information
// It supports JSON serialization for external usage.
class CpuSampling : public Sampling {
using TimeStamp = std::vector<uint64_t>;
public:
explicit CpuSampling(ExecutionTree *tree) : tree_(tree) {}
~CpuSampling() = default;
// Driver function for CPU sampling.
// This function samples the CPU information of device/process/op
Status Sample() override;
std::string Name() const override { return kCpuSamplingName; }
// Save sampling data to file
// @return Status - The error code returned
Status SaveToFile() override;
Status Init(const std::string &dir_path, const std::string &device_id) override;
// Change file mode after saving CPU data
Status ChangeFileMode() override;
// Analyze sampling data and print message to log
Status Analyze() override;
private:
Status CollectTimeStamp();
Status SaveTimeStampToFile();
Status SaveSamplingIntervalToFile();
ExecutionTree *tree_ = nullptr; // ExecutionTree pointer
std::vector<std::shared_ptr<BaseCpu>> cpu_; // CPU information of device/process/op
TimeStamp time_stamp_; // Time stamp
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CPU_SAMPLING_H

View File

@ -13,49 +13,66 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/stat.h>
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
#include <fstream>
#include <string>
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#else
#include "mindspore/lite/src/common/log_adapter.h"
#endif
#include "minddata/dataset/util/path.h"
#include "mindspore/core/utils/ms_utils.h"
namespace mindspore {
namespace dataset {
Status DatasetIteratorTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num,
const int32_t value, const uint64_t time_stamp) {
// Format: "type extra-info batch-num value"
// type: 0: time, 1: connector size
// extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time
// if type is 1 - connector capacity
// batch-num: batch number
// value: if type is 0 - value is time(ms)
// if type is 1 - value is connector size
// Examples:
// 0 0 20 10 - The 20th batch took 10ms to get data from the pipeline.
// 1 64 20 5 - Connector size is 5 when getting the 20th batch. Connector capacity is 64.
std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " +
std::to_string(value) + " " + std::to_string(time_stamp);
value_.emplace_back(data);
return Status::OK();
}
constexpr int32_t CONNECTOR_CAPACITY_OFFSET = 0;
Status DatasetIteratorTracing::Init(const std::string &dir_path, const std::string &device_id) {
file_path_ = (Path(dir_path) / Path("dataset_iterator_profiling_" + device_id + ".txt")).ToString();
return Status::OK();
}
Status DatasetIteratorTracing::ChangeFileMode() {
if (value_.empty()) {
return Status::OK();
}
if (chmod(common::SafeCStr(file_path_), S_IRUSR | S_IWUSR) == -1) {
std::string err_str = "Change file mode failed, " + file_path_;
return Status(StatusCode::kMDUnexpectedError, err_str);
}
return Status::OK();
}
Status DatasetIteratorTracing::GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return {StatusCode::kMDUnexpectedError, "Dataset Iterator Tracing does not record pipeline time."};
}
Status DatasetIteratorTracing::GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return {StatusCode::kMDUnexpectedError, "Dataset Iterator Tracing does not record push time."};
}
Status DatasetIteratorTracing::GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return {StatusCode::kMDUnexpectedError, "Dataset Iterator Tracing does not record batch time."};
}
Status DatasetIteratorTracing::GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntry(start_step, end_step, CONNECTOR_CAPACITY_OFFSET, result);
}
Status DatasetIteratorTracing::GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) {
std::lock_guard<std::mutex> guard(lock_);
auto total_steps = records_.size() / records_per_step_;
MS_LOG(DEBUG) << "start_step: " << start_step << " end_step: " << end_step;
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= total_steps,
"Expected start_step <= total_steps. Got start_step: " + std::to_string(start_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(end_step <= total_steps,
"Expected end_step <= total_steps. Got end_step: " + std::to_string(end_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= end_step,
"Expected start_step <= end_step. Got start_step: " + std::to_string(start_step) +
" end_step: " + std::to_string(end_step));
uint32_t total = end_step - start_step + 1;
uint32_t count = 0U;
for (auto step_num = start_step; step_num <= end_step; step_num++) {
auto idx = (step_num - 1) * records_per_step_ + CONNECTOR_CAPACITY_OFFSET;
count += static_cast<uint32_t>(records_[idx].value == 0);
}
*empty_queue_freq = static_cast<float_t>(count) / static_cast<float_t>(total);
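// e.g. if the connector was empty at 5 of the 100 requested steps,
// *empty_queue_freq comes out as 0.05.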
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -23,24 +23,24 @@
namespace mindspore {
namespace dataset {
constexpr int32_t RECORDS_PER_STEP_DATASET_ITERATOR = 1;
class DatasetIteratorTracing : public Tracing {
public:
// Constructor
DatasetIteratorTracing() = default;
DatasetIteratorTracing() : Tracing(RECORDS_PER_STEP_DATASET_ITERATOR) {}
// Destructor
~DatasetIteratorTracing() override = default;
// Record tracing data
// @return Status The status code returned
Status Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value,
const uint64_t time_stamp);
std::string Name() const override { return kDatasetIteratorTracingName; };
Status Init(const std::string &dir_path, const std::string &device_id) override;
Status ChangeFileMode() override;
Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) override;
};
} // namespace dataset
} // namespace mindspore

View File

@ -14,47 +14,67 @@
* limitations under the License.
*/
#include <sys/stat.h>
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
#include <fstream>
#include <string>
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
#ifndef ENABLE_ANDROID
#include "utils/log_adapter.h"
#else
#include "mindspore/lite/src/common/log_adapter.h"
#endif
#include "minddata/dataset/util/path.h"
#include "mindspore/core/utils/ms_utils.h"
namespace mindspore {
namespace dataset {
void DeviceQueueTracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num,
const int32_t value, const uint64_t time_stamp) {
// Format: "type extra-info batch-num value"
// type: 0: time, 1: connector size
// extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time
// if type is 1 - connector capacity
// batch-num: batch number
// value: if type is 0 - value is time(ms)
// if type is 1 - value is connector size
// time-stamp: time stamp
// Examples:
// 0 0 20 10 xxx - The 20th batch took 10ms to get data from the pipeline.
// 1 64 20 5 xxx - Connector size is 5 when getting the 20th batch. Connector capacity is 64.
std::string data = std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " +
std::to_string(value) + " " + std::to_string(time_stamp);
value_.emplace_back(data);
}
constexpr int32_t PUSH_TIME_OFFSET = 0;
constexpr int32_t BATCH_TIME_OFFSET = 1;
constexpr int32_t PIPELINE_TIME_OFFSET = 2;
constexpr int32_t CONNECTOR_CAPACITY_OFFSET = 3;
Status DeviceQueueTracing::Init(const std::string &dir_path, const std::string &device_id) {
file_path_ = (Path(dir_path) / Path("device_queue_profiling_" + device_id + ".txt")).ToString();
return Status::OK();
}
Status DeviceQueueTracing::ChangeFileMode() {
if (value_.empty()) {
return Status::OK();
}
if (chmod(common::SafeCStr(file_path_), S_IRUSR | S_IWUSR) == -1) {
std::string err_str = "Change file mode failed, " + file_path_;
return Status(StatusCode::kMDUnexpectedError, err_str);
}
return Status::OK();
}
Status DeviceQueueTracing::GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntry(start_step, end_step, PIPELINE_TIME_OFFSET, result);
}
Status DeviceQueueTracing::GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntry(start_step, end_step, PUSH_TIME_OFFSET, result);
}
Status DeviceQueueTracing::GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntry(start_step, end_step, BATCH_TIME_OFFSET, result);
}
Status DeviceQueueTracing::GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) {
return GetRecordEntry(start_step, end_step, CONNECTOR_CAPACITY_OFFSET, result);
}
Status DeviceQueueTracing::GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) {
std::lock_guard<std::mutex> guard(lock_);
auto total_steps = records_.size() / records_per_step_;
MS_LOG(DEBUG) << "start_step: " << start_step << " end_step: " << end_step;
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= total_steps,
"Expected start_step <= total_steps. Got start_step: " + std::to_string(start_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(end_step <= total_steps,
"Expected end_step <= total_steps. Got end_step: " + std::to_string(end_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= end_step,
"Expected start_step <= end_step. Got start_step: " + std::to_string(start_step) +
" end_step: " + std::to_string(end_step));
uint32_t total = end_step - start_step + 1;
uint32_t count = 0U;
for (auto step_num = start_step; step_num <= end_step; step_num++) {
auto idx = (step_num - 1) * records_per_step_ + CONNECTOR_CAPACITY_OFFSET;
count += static_cast<uint32_t>(records_[idx].value == 0);
}
*empty_queue_freq = static_cast<float_t>(count) / static_cast<float_t>(total);
return Status::OK();
}
} // namespace dataset

View File

@ -23,24 +23,24 @@
namespace mindspore {
namespace dataset {
constexpr int32_t RECORDS_PER_STEP_DEVICE_QUEUE = 4;
class DeviceQueueTracing : public Tracing {
public:
// Constructor
DeviceQueueTracing() = default;
DeviceQueueTracing() : Tracing(RECORDS_PER_STEP_DEVICE_QUEUE) {}
// Destructor
~DeviceQueueTracing() override = default;
// Record tracing data
// @return Status The status code returned
void Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value,
const uint64_t time_stamp);
std::string Name() const override { return kDeviceQueueTracingName; };
Status Init(const std::string &dir_path, const std::string &device_id) override;
Status ChangeFileMode() override;
Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) override;
Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) override;
};
} // namespace dataset
} // namespace mindspore

View File

@ -60,7 +60,6 @@ Status Monitor::operator()() {
RETURN_IF_NOT_OK(tree_consumer_->GetProfilingManager()->Analyze());
RETURN_IF_NOT_OK(tree_consumer_->GetProfilingManager()->SaveProfilingData());
RETURN_IF_NOT_OK(tree_consumer_->GetProfilingManager()->ChangeFileMode());
cfg->set_profiler_file_status(true);
return Status::OK();
}

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/perf/profiling.h"
#include <sys/stat.h>
#include <cstdlib>
#include <fstream>
#include "utils/ms_utils.h"
@ -25,9 +26,9 @@
#include "minddata/dataset/engine/perf/monitor.h"
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
#include "minddata/dataset/engine/perf/connector_size.h"
#include "minddata/dataset/engine/perf/connector_throughput.h"
#include "minddata/dataset/engine/perf/cpu_sampling.h"
#include "minddata/dataset/engine/perf/cpu_sampler.h"
#include "minddata/dataset/engine/perf/dataset_iterator_tracing.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/util/log_adapter.h"
namespace mindspore {
@ -50,6 +51,63 @@ Status Tracing::SaveToFile() {
return Status::OK();
}
Status Tracing::ChangeFileMode() {
if (value_.empty()) {
return Status::OK();
}
if (chmod(common::SafeCStr(file_path_), S_IRUSR | S_IWUSR) == -1) {
std::string err_str = "Change file mode failed," + file_path_;
return Status(StatusCode::kMDUnexpectedError, err_str);
}
return Status::OK();
}
void Tracing::Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value,
const uint64_t time_stamp) {
// Format: "type extra-info batch-num value"
// type: 0: time, 1: connector size
// extra-info: if type is 0 - 0: pipeline time, 1: push tdt time, 2: batch time
// if type is 1 - connector capacity
// batch-num: batch number
// value: if type is 0 - value is time(ms)
// if type is 1 - value is connector size
// time-stamp: time stamp
// Examples:
// 0 0 20 10 xxx - The 20th batch took 10ms to get data from the pipeline.
// 1 64 20 5 xxx - Connector size is 5 when getting the 20th batch. Connector capacity is 64.
TracingRecord record = {type, extra_info, batch_num, value, time_stamp};
std::lock_guard<std::mutex> guard(lock_);
(void)records_.emplace_back(record);
(void)value_.emplace_back(record.ToString());
}
Status Tracing::GetRecordEntry(int32_t start_step, int32_t end_step, int32_t record_offset,
std::vector<int32_t> *result) {
std::lock_guard<std::mutex> guard(lock_);
auto total_steps = records_.size() / records_per_step_;
MS_LOG(DEBUG) << "start_step: " << start_step << " end_step: " << end_step;
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= total_steps,
"Expected start_step <= total_steps. Got start_step: " + std::to_string(start_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(end_step <= total_steps,
"Expected end_step <= total_steps. Got end_step: " + std::to_string(end_step) +
" total_steps: " + std::to_string(total_steps));
CHECK_FAIL_RETURN_UNEXPECTED(start_step <= end_step,
"Expected start_step <= end_step. Got start_step: " + std::to_string(start_step) +
" end_step: " + std::to_string(end_step));
for (auto step_num = start_step; step_num <= end_step; step_num++) {
// Each step has records_per_step_ entries; record_offset selects the entry within the step
auto idx = (step_num - 1) * records_per_step_ + record_offset;
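// e.g. for the device queue tracing (records_per_step_ = 4, PIPELINE_TIME_OFFSET = 2),
// the pipeline time of step 20 lives at index (20 - 1) * 4 + 2.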
assert(idx < records_.size());
(void)result->emplace_back(records_[idx].value);
}
return Status::OK();
}
Tracing::Tracing(int32_t records_per_step) : records_per_step_(records_per_step) {}
Status Sampling::ReadJson(nlohmann::json *output) {
RETURN_UNEXPECTED_IF_NULL(output);
Path path = Path(file_path_);
@ -134,13 +192,14 @@ Status ProfilingManager::Initialize(ExecutionTree *tree) {
std::shared_ptr<Sampling> connector_size_sampling = std::make_shared<ConnectorSize>(tree_);
RETURN_IF_NOT_OK(RegisterSamplingNode(connector_size_sampling));
std::shared_ptr<Sampling> connector_thr_sampling = std::make_shared<ConnectorThroughput>(tree_);
RETURN_IF_NOT_OK(RegisterSamplingNode(connector_thr_sampling));
#ifndef ENABLE_ANDROID
std::shared_ptr<Sampling> cpu_sampling = std::make_shared<CpuSampling>(tree_);
RETURN_IF_NOT_OK(RegisterSamplingNode(cpu_sampling));
std::shared_ptr<Sampling> cpu_sampler = std::make_shared<CpuSampler>(tree_);
RETURN_IF_NOT_OK(RegisterSamplingNode(cpu_sampler));
#endif
// Insert a dummy timestamp and step for "epoch 0" so that epoch numbers can be used
// as indices directly and samples taken during pipeline start-up can be ignored.
(void)epoch_end_ts_.emplace_back(0);
(void)epoch_end_step_.emplace_back(0);
return Status::OK();
}
@ -214,6 +273,7 @@ Status ProfilingManager::SaveProfilingData() {
MS_LOG(INFO) << "Save profiling data end.";
return Status::OK();
}
Status ProfilingManager::Analyze() {
if (!IsProfilingEnable()) {
return Status::OK();
@ -240,8 +300,138 @@ Status ProfilingManager::ChangeFileMode() {
return Status::OK();
}
#ifndef ENABLE_ANDROID
Status ProfilingManager::GetUserCpuUtil(int32_t epoch_num, std::vector<uint8_t> *result) {
std::shared_ptr<CpuSampler> cpu_node;
uint64_t start_ts, end_ts;
RETURN_IF_NOT_OK(PopulateCpuSamplerAPIInputs(epoch_num, &start_ts, &end_ts, &cpu_node));
return cpu_node->GetSystemUserCpuUtil(start_ts, end_ts, result);
}
Status ProfilingManager::GetSysCpuUtil(int32_t epoch_num, std::vector<uint8_t> *result) {
std::shared_ptr<CpuSampler> cpu_node;
uint64_t start_ts, end_ts;
RETURN_IF_NOT_OK(PopulateCpuSamplerAPIInputs(epoch_num, &start_ts, &end_ts, &cpu_node));
return cpu_node->GetSystemSysCpuUtil(start_ts, end_ts, result);
}
Status ProfilingManager::GetUserCpuUtil(int32_t op_id, int32_t epoch_num, std::vector<uint16_t> *result) {
std::shared_ptr<CpuSampler> cpu_node;
uint64_t start_ts, end_ts;
RETURN_IF_NOT_OK(PopulateCpuSamplerAPIInputs(epoch_num, &start_ts, &end_ts, &cpu_node));
return cpu_node->GetOpUserCpuUtil(op_id, start_ts, end_ts, result);
}
Status ProfilingManager::GetSysCpuUtil(int32_t op_id, int32_t epoch_num, std::vector<uint16_t> *result) {
std::shared_ptr<CpuSampler> cpu_node;
uint64_t start_ts, end_ts;
RETURN_IF_NOT_OK(PopulateCpuSamplerAPIInputs(epoch_num, &start_ts, &end_ts, &cpu_node));
return cpu_node->GetOpSysCpuUtil(op_id, start_ts, end_ts, result);
}
Status ProfilingManager::PopulateCpuSamplerAPIInputs(int32_t epoch_num, uint64_t *start_ts, uint64_t *end_ts,
std::shared_ptr<CpuSampler> *node) {
RETURN_IF_NOT_OK(EpochToTimeInterval(epoch_num, start_ts, end_ts));
std::shared_ptr<Sampling> sampling_node;
RETURN_IF_NOT_OK(GetSamplingNode(kCpuSamplerName, &sampling_node));
*node = std::dynamic_pointer_cast<CpuSampler>(sampling_node);
return Status::OK();
}
#endif
Status ProfilingManager::EpochToTimeInterval(int32_t epoch_num, uint64_t *start_ts, uint64_t *end_ts) {
if (epoch_num <= 0 || static_cast<size_t>(epoch_num) >= epoch_end_ts_.size()) {
std::string err = "Epoch: " + std::to_string(epoch_num) + " is invalid.";
MS_LOG(INFO) << err;
return {StatusCode::kMDUnexpectedError, err};
}
*start_ts = epoch_end_ts_[epoch_num - 1];
*end_ts = epoch_end_ts_[epoch_num];
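// e.g. with the dummy entry at index 0, epoch 1 spans epoch_end_ts_[0] .. epoch_end_ts_[1].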
return Status::OK();
}
Status ProfilingManager::EpochToStepInterval(int32_t epoch_num, uint32_t *start_step, uint32_t *end_step) {
if (epoch_num <= 0 || static_cast<size_t>(epoch_num) >= epoch_end_step_.size()) {
std::string err = "Epoch: " + std::to_string(epoch_num) + " is invalid.";
MS_LOG(INFO) << err;
return {StatusCode::kMDUnexpectedError, err};
}
*start_step = epoch_end_step_[epoch_num - 1] + 1;
*end_step = epoch_end_step_[epoch_num];
return Status::OK();
}
Status ProfilingManager::GetConnectorSize(int32_t op_id, int32_t epoch_num, std::vector<int32_t> *result) {
uint64_t start_ts, end_ts;
RETURN_IF_NOT_OK(EpochToTimeInterval(epoch_num, &start_ts, &end_ts));
std::shared_ptr<Sampling> node;
RETURN_IF_NOT_OK(GetSamplingNode(kConnectorSizeSamplingName, &node));
auto connector_node = std::dynamic_pointer_cast<ConnectorSize>(node);
return connector_node->GetOpConnectorSize(op_id, start_ts, end_ts, result);
}
Status ProfilingManager::GetPipelineTime(int32_t epoch_num, std::vector<int32_t> *result) {
uint32_t start_step, end_step;
RETURN_IF_NOT_OK(EpochToStepInterval(epoch_num, &start_step, &end_step));
std::shared_ptr<Tracing> node;
if (GetTracingNode(kDeviceQueueTracingName, &node).IsOk() ||
GetTracingNode(kDatasetIteratorTracingName, &node).IsOk()) {
return node->GetPipelineTime(start_step, end_step, result);
} else {
return {StatusCode::kMDUnexpectedError, "Cannot find appropriate tracing node"};
}
}
Status ProfilingManager::GetPushTime(int32_t epoch_num, std::vector<int32_t> *result) {
uint32_t start_step, end_step;
RETURN_IF_NOT_OK(EpochToStepInterval(epoch_num, &start_step, &end_step));
std::shared_ptr<Tracing> node;
if (GetTracingNode(kDeviceQueueTracingName, &node).IsOk() ||
GetTracingNode(kDatasetIteratorTracingName, &node).IsOk()) {
return node->GetPushTime(start_step, end_step, result);
} else {
return {StatusCode::kMDUnexpectedError, "Cannot find appropriate tracing node"};
}
}
Status ProfilingManager::GetBatchTime(int32_t epoch_num, std::vector<int32_t> *result) {
uint32_t start_step, end_step;
RETURN_IF_NOT_OK(EpochToStepInterval(epoch_num, &start_step, &end_step));
std::shared_ptr<Tracing> node;
if (GetTracingNode(kDeviceQueueTracingName, &node).IsOk() ||
GetTracingNode(kDatasetIteratorTracingName, &node).IsOk()) {
return node->GetBatchTime(start_step, end_step, result);
} else {
return {StatusCode::kMDUnexpectedError, "Cannot find appropriate tracing node"};
}
}
Status ProfilingManager::GetConnectorSize(int32_t epoch_num, std::vector<int32_t> *result) {
uint32_t start_step, end_step;
RETURN_IF_NOT_OK(EpochToStepInterval(epoch_num, &start_step, &end_step));
std::shared_ptr<Tracing> node;
if (GetTracingNode(kDeviceQueueTracingName, &node).IsOk() ||
GetTracingNode(kDatasetIteratorTracingName, &node).IsOk()) {
return node->GetConnectorSize(start_step, end_step, result);
} else {
return {StatusCode::kMDUnexpectedError, "Cannot find appropriate tracing node"};
}
}
Status ProfilingManager::GetEmptyQueueFrequency(int32_t epoch_num, float_t *result) {
uint32_t start_step, end_step;
RETURN_IF_NOT_OK(EpochToStepInterval(epoch_num, &start_step, &end_step));
std::shared_ptr<Tracing> node;
if (GetTracingNode(kDeviceQueueTracingName, &node).IsOk() ||
GetTracingNode(kDatasetIteratorTracingName, &node).IsOk()) {
return node->GetEmptyQueueFrequency(start_step, end_step, result);
} else {
return {StatusCode::kMDUnexpectedError, "Cannot find appropriate tracing node"};
}
}
void ProfilingManager::RecordEndOfEpoch(uint32_t step_num) {
MS_LOG(INFO) << "Record end of epoch. step_num: " << step_num;
MS_LOG(INFO) << "Recording end of epoch. step_num: " << step_num;
(void)epoch_end_ts_.emplace_back(ProfilingTime::GetCurMilliSecond());
(void)epoch_end_step_.emplace_back(step_num);
}

View File

@ -21,6 +21,7 @@
#include <unordered_map>
#include <memory>
#include <chrono>
#include <mutex>
#include <nlohmann/json.hpp>
#include "minddata/dataset/util/status.h"
@ -30,12 +31,12 @@ namespace dataset {
class Monitor;
class ExecutionTree;
class TreeConsumer;
class CpuSampler;
const char kDeviceQueueTracingName[] = "Device_Queue_Tracing";
const char kDatasetIteratorTracingName[] = "Dataset_Iterator_Tracing";
const char kConnectorSizeSamplingName[] = "Connector_Size_Sampling";
const char kConnectorThroughputSamplingName[] = "Connector_Throughput_Sampling";
const char kCpuSamplingName[] = "Cpu_Sampling";
const char kCpuSamplerName[] = "Cpu_Sampler";
// Profiling is a class of basic unit of profiling action
// This base class encapsulate the serialization output logic
@ -59,6 +60,7 @@ class Profiling : std::enable_shared_from_this<Profiling> {
protected:
std::string file_path_;
std::mutex lock_;
};
// Sampling is a class of profiling which generates samples periodically.
@ -72,15 +74,40 @@ class Sampling : public Profiling {
Status ReadJson(nlohmann::json *output);
};
typedef struct TracingRecord_s {
int32_t type;
int32_t extra_info;
int32_t batch_num;
int32_t value;
uint64_t ts;
std::string ToString() const {
return std::to_string(type) + " " + std::to_string(extra_info) + " " + std::to_string(batch_num) + " " +
std::to_string(value) + " " + std::to_string(ts);
}
} TracingRecord;
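// e.g. a connector-size record {1, 64, 20, 5, ts} serializes to "1 64 20 5 <ts>".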
// Tracing is a class of profiling which records samples upon request.
class Tracing : public Profiling {
public:
// Tracing has a minimal interface to provide flexibility in data recording.
// It only includes some common routines.
Status SaveToFile();
Status SaveToFile() override;
Status ChangeFileMode() override;
virtual Status GetPipelineTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetPushTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetBatchTime(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetConnectorSize(int32_t start_step, int32_t end_step, std::vector<int32_t> *result) = 0;
virtual Status GetEmptyQueueFrequency(int32_t start_step, int32_t end_step, float_t *empty_queue_freq) = 0;
void Record(const int32_t type, const int32_t extra_info, const int32_t batch_num, const int32_t value,
const uint64_t time_stamp);
protected:
explicit Tracing(int32_t records_per_step);
const int32_t records_per_step_;
std::vector<std::string> value_;
std::vector<TracingRecord> records_;
Status GetRecordEntry(int32_t start_step, int32_t end_step, int32_t record_offset, std::vector<int32_t> *result);
};
// ProfilingManager is a class manages all profiling infrastructure
@ -135,13 +162,80 @@ class ProfilingManager {
// Analyze profile data and print warning messages
Status Analyze();
#ifndef ENABLE_ANDROID
/// \brief API to get User CPU utilization for the system
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the sampled User CPU Utilization for the entire system
/// \return Status object with the error code
Status GetUserCpuUtil(int32_t epoch_num, std::vector<uint8_t> *result);
/// \brief API to get System CPU utilization for the system
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the sampled System CPU Utilization for the entire system
/// \return Status object with the error code
Status GetSysCpuUtil(int32_t epoch_num, std::vector<uint8_t> *result);
/// \brief API to get User CPU Utilization of an MD operator
/// \param [in] op_id The id of the operator
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the sampled User CPU Utilization of the operator.
/// \return Status object with the error code
Status GetUserCpuUtil(int32_t op_id, int32_t epoch_num, std::vector<uint16_t> *result);
/// \brief API to get System CPU Utilization of an MD operator
/// \param [in] op_id The id of the operator
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the sampled System CPU Utilization of the operator.
/// \return Status object with the error code
Status GetSysCpuUtil(int32_t op_id, int32_t epoch_num, std::vector<uint16_t> *result);
#endif
/// \brief API to get the connector size of an MD operator
/// \param [in] op_id The id of the operator
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the sampled connector sizes of the operator
/// \return Status object with the error code
Status GetConnectorSize(int32_t op_id, int32_t epoch_num, std::vector<int32_t> *result);
/// \brief API to get the connector size of DatasetIterator or DeviceQueueOp
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with connector size at each step
/// \return Status object with the error code
Status GetConnectorSize(int32_t epoch_num, std::vector<int32_t> *result);
/// \brief API to get the pipeline time of batches
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the pipeline time for each step
/// \return Status object with the error code
Status GetPipelineTime(int32_t epoch_num, std::vector<int32_t> *result);
/// \brief API to get the push time of batches
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the push time for each step
/// \return Status object with the error code
Status GetPushTime(int32_t epoch_num, std::vector<int32_t> *result);
/// \brief API to get the batch time of batches
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result A vector with the batch time for each step
/// \return Status object with the error code
Status GetBatchTime(int32_t epoch_num, std::vector<int32_t> *result);
/// \brief API to get fraction of steps that DatasetIterator or DeviceQueueOp connector was empty
/// \param [in] epoch_num The epoch number for which results are requested
/// \param [out] result The empty queue frequency
/// \return Status object with the error code
Status GetEmptyQueueFrequency(int32_t epoch_num, float_t *result);
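// A minimal usage sketch (hypothetical caller code; assumes profiling is enabled
// and epoch 1 has completed):
//   std::vector<int32_t> batch_times;
//   RETURN_IF_NOT_OK(profiling_manager->GetBatchTime(1, &batch_times));
//   std::vector<uint8_t> sys_util;
//   RETURN_IF_NOT_OK(profiling_manager->GetSysCpuUtil(1, &sys_util));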
private:
std::unique_ptr<Monitor> perf_monitor_;
bool enabled_;
std::unordered_map<std::string, std::shared_ptr<Tracing>> tracing_nodes_;
std::unordered_map<std::string, std::shared_ptr<Sampling>> sampling_nodes_;
ExecutionTree *tree_; // ExecutionTree pointer
TreeConsumer *tree_consumer_; // TreeConsumer pointer
std::string dir_path_; // where to create profiling file
std::string device_id_; // used when create profiling file,filename_device_id.suffix
std::vector<uint64_t> epoch_end_ts_; // End of epoch timestamp
std::vector<uint32_t> epoch_end_step_; // End of epoch step number
@ -155,10 +249,13 @@ class ProfilingManager {
// @return Status The status code returned
Status RegisterSamplingNode(std::shared_ptr<Sampling> node);
ExecutionTree *tree_; // ExecutionTree pointer
TreeConsumer *tree_consumer_; // TreeConsumer pointer
std::string dir_path_; // where to create profiling file
std::string device_id_; // used when create profiling file,filename_device_id.suffix
Status EpochToStepInterval(int32_t epoch_num, uint32_t *start_step, uint32_t *end_step);
// get start and ending timestamp of an epoch
Status EpochToTimeInterval(int32_t epoch_num, uint64_t *start_ts, uint64_t *end_ts);
#ifndef ENABLE_ANDROID
Status PopulateCpuSamplerAPIInputs(int32_t epoch_num, uint64_t *start_ts, uint64_t *end_ts,
std::shared_ptr<CpuSampler> *node);
#endif
};
enum ProfilingType { TIME, CONNECTOR_DEPTH };

View File

@ -257,8 +257,7 @@ Status TreeAdapter::GetNext(TensorRow *row) {
cur_batch_num_++;
cur_connector_size_ = tree_->root()->ConnectorSize();
cur_connector_capacity_ = tree_->root()->ConnectorCapacity();
RETURN_IF_NOT_OK(
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_, end_time));
tracing_->Record(CONNECTOR_DEPTH, cur_connector_capacity_, cur_batch_num_, cur_connector_size_, end_time);
}
#endif
return Status::OK();

View File

@ -180,7 +180,6 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/engine/perf/monitor.cc
${MINDDATA_DIR}/engine/perf/device_queue_tracing.cc
${MINDDATA_DIR}/engine/perf/connector_size.cc
${MINDDATA_DIR}/engine/perf/connector_throughput.cc
${MINDDATA_DIR}/engine/perf/dataset_iterator_tracing.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/sampler.cc
${MINDDATA_DIR}/engine/datasetops/source/sampler/subset_sampler.cc

View File

@ -82,8 +82,7 @@ def confirm_cpuutil(num_pipeline_ops, cpu_uti_file):
with open(cpu_uti_file) as file1:
data = json.load(file1)
op_info = data["op_info"]
# Confirm <num_pipeline_ops>+1 ops in CPU util file (including op_id=-1 for monitor thread)
assert len(op_info) == num_pipeline_ops + 1
assert len(op_info) == num_pipeline_ops
def confirm_ops_in_pipeline(num_ops, op_list, pipeline_file):
@ -176,7 +175,6 @@ def test_profiling_complex_pipeline():
if op_info[i]["op_type"] != "ZipOp":
assert "size" in op_info[i]["metrics"]["output_queue"]
assert "length" in op_info[i]["metrics"]["output_queue"]
assert "throughput" in op_info[i]["metrics"]["output_queue"]
else:
# Note: Zip is an inline op and hence does not have metrics information
assert op_info[i]["metrics"] is None
@ -243,7 +241,6 @@ def test_profiling_inline_ops_pipeline1():
else:
assert "size" in op_info[i]["metrics"]["output_queue"]
assert "length" in op_info[i]["metrics"]["output_queue"]
assert "throughput" in op_info[i]["metrics"]["output_queue"]
# Confirm CPU util JSON file content, when 4 ops are in the pipeline JSON file
confirm_cpuutil(4, cpu_util_file)
@ -294,7 +291,6 @@ def test_profiling_inline_ops_pipeline2():
else:
assert "size" in op_info[i]["metrics"]["output_queue"]
assert "length" in op_info[i]["metrics"]["output_queue"]
assert "throughput" in op_info[i]["metrics"]["output_queue"]
# Confirm CPU util JSON file content, when 5 ops are in the pipeline JSON file
confirm_cpuutil(5, cpu_util_file)
@ -384,7 +380,6 @@ def test_profiling_basic_pipeline():
else:
assert "size" in op_info[i]["metrics"]["output_queue"]
assert "length" in op_info[i]["metrics"]["output_queue"]
assert "throughput" in op_info[i]["metrics"]["output_queue"]
# Confirm CPU util JSON file content, when 5 ops are in the pipeline JSON file
confirm_cpuutil(5, cpu_util_file)
@ -441,7 +436,6 @@ def test_profiling_cifar10_pipeline():
else:
assert "size" in op_info[i]["metrics"]["output_queue"]
assert "length" in op_info[i]["metrics"]["output_queue"]
assert "throughput" in op_info[i]["metrics"]["output_queue"]
# Confirm CPU util JSON file content, when 5 ops are in the pipeline JSON file
confirm_cpuutil(5, cpu_util_file)