Enhance the perfomance of A+M dump: Parallelize ConverFormatForTensorAndDump and refactoring

This commit is contained in:
TinaMengtingZhang 2022-02-07 16:39:05 -05:00
parent b3c5943bf8
commit 4a8d3defe7
5 changed files with 159 additions and 89 deletions

View File

@ -635,6 +635,35 @@ void E2eDump::DumpParametersData(uint32_t rank_id, const Debugger *debugger) {
} }
#ifdef ENABLE_D #ifdef ENABLE_D
template <typename T>
dump_data_t ParseAttrsFromDumpData(const std::string &dump_path, char *data_ptr, const T &tensor, const std::string &io,
uint32_t slot) {
// get data type
auto iter_dtype = kDataTypetoMSTypeMap.find(tensor.data_type());
if (iter_dtype == kDataTypetoMSTypeMap.end()) {
MS_LOG(INFO) << "Unsupported data type for tensor " << dump_path << ": unknown(" << tensor.data_type() << ")";
return dump_data_t{};
}
auto data_type = iter_dtype->second;
// get format
auto iter_fmt = kFormatToStringMap.find(tensor.format());
if (iter_fmt == kFormatToStringMap.end()) {
MS_LOG(INFO) << "Unsupported tensor format for tensor " << dump_path << ": unknown(" << tensor.format() << ")";
return dump_data_t{};
}
std::string device_format = iter_fmt->second;
// get shape
ShapeVector shape_d;
(void)std::transform(tensor.shape().dim().begin(), tensor.shape().dim().end(), std::back_inserter(shape_d),
SizeToLong);
ShapeVector shape_to;
(void)std::transform(tensor.original_shape().dim().begin(), tensor.original_shape().dim().end(),
std::back_inserter(shape_to), SizeToLong);
// get size and sub_format
size_t data_size = (size_t)tensor.size();
int32_t sub_format = tensor.sub_format();
return dump_data_t{dump_path, data_ptr, data_type, device_format, shape_d, shape_to, data_size, sub_format, io, slot};
}
/* /*
* Feature group: Dump. * Feature group: Dump.
* Target device group: Ascend. * Target device group: Ascend.
@ -644,15 +673,13 @@ void E2eDump::DumpParametersData(uint32_t rank_id, const Debugger *debugger) {
*/ */
void E2eDump::DumpTensorToFile(const std::string &dump_path, const debugger::dump::DumpData &dump_data, void E2eDump::DumpTensorToFile(const std::string &dump_path, const debugger::dump::DumpData &dump_data,
char *data_ptr) { char *data_ptr) {
std::vector<dump_data_t> dump_tensor_vec;
// dump input tensors // dump input tensors
std::vector<debugger::dump::OpInput> input_tensors(dump_data.input().begin(), dump_data.input().end()); std::vector<debugger::dump::OpInput> input_tensors(dump_data.input().begin(), dump_data.input().end());
uint64_t offset = 0; uint64_t offset = 0;
for (uint32_t slot = 0; slot < input_tensors.size(); slot++) { for (uint32_t slot = 0; slot < input_tensors.size(); slot++) {
auto in_tensor = input_tensors[slot]; auto in_tensor = input_tensors[slot];
auto succ = ConvertFormatForTensorAndDump(dump_path, in_tensor, data_ptr + offset, "input", slot); dump_tensor_vec.push_back(ParseAttrsFromDumpData(dump_path, data_ptr + offset, in_tensor, "input", slot));
if (!succ) {
MS_LOG(INFO) << "Failed to convert format for tensor " << dump_path << ".input." << slot;
}
offset += in_tensor.size(); offset += in_tensor.size();
} }
@ -660,12 +687,44 @@ void E2eDump::DumpTensorToFile(const std::string &dump_path, const debugger::dum
std::vector<debugger::dump::OpOutput> output_tensors(dump_data.output().begin(), dump_data.output().end()); std::vector<debugger::dump::OpOutput> output_tensors(dump_data.output().begin(), dump_data.output().end());
for (uint32_t slot = 0; slot < output_tensors.size(); slot++) { for (uint32_t slot = 0; slot < output_tensors.size(); slot++) {
auto out_tensor = output_tensors[slot]; auto out_tensor = output_tensors[slot];
auto succ = ConvertFormatForTensorAndDump(dump_path, out_tensor, data_ptr + offset, "output", slot); dump_tensor_vec.push_back(ParseAttrsFromDumpData(dump_path, data_ptr + offset, out_tensor, "output", slot));
if (!succ) {
MS_LOG(INFO) << "Failed to convert format for tensor " << dump_path << ".output." << slot;
}
offset += out_tensor.size(); offset += out_tensor.size();
} }
// assign slot conversion task to different thread.
if (dump_tensor_vec.empty()) {
return;
}
auto default_num_workers = std::max<uint32_t>(1, std::thread::hardware_concurrency() / 4);
auto num_threads = std::min<uint32_t>(default_num_workers, dump_tensor_vec.size());
uint32_t task_size = dump_tensor_vec.size() / num_threads;
uint32_t remainder = dump_tensor_vec.size() % num_threads;
std::vector<std::thread> threads;
threads.reserve(num_threads);
MS_LOG(INFO) << "Number of threads used for A+M dump: " << num_threads;
for (size_t t = 0; t < threads.capacity(); t++) {
uint32_t start_idx = t * task_size;
uint32_t end_idx = start_idx + task_size - 1;
if (t == num_threads - 1) {
end_idx += remainder;
}
threads.emplace_back(std::thread(&E2eDump::ConvertFormatForTensors, std::ref(dump_tensor_vec), start_idx, end_idx));
}
for (size_t t = 0; t < threads.capacity(); t++) {
threads[t].join();
}
}
void E2eDump::ConvertFormatForTensors(const std::vector<dump_data_t> &dump_tensor_vec, uint32_t start_idx,
uint32_t end_idx) {
for (uint32_t idx = start_idx; idx <= end_idx; idx++) {
auto succ = ConvertFormatForTensorAndDump(dump_tensor_vec[idx]);
if (!succ) {
MS_LOG(INFO) << "Failed to convert format for tensor " << dump_tensor_vec[idx].dump_file_path << "."
<< dump_tensor_vec[idx].in_out_str << "." << dump_tensor_vec[idx].slot;
}
}
} }
/* /*
@ -674,13 +733,12 @@ void E2eDump::DumpTensorToFile(const std::string &dump_path, const debugger::dum
* Runtime category: Old runtime, MindRT. * Runtime category: Old runtime, MindRT.
* Description: It serves for A+M dump. Save statistic of the tensor data into dump path as configured. * Description: It serves for A+M dump. Save statistic of the tensor data into dump path as configured.
*/ */
template <typename T> bool DumpTensorStatsIfNeeded(const dump_data_t &dump_tensor_info, char *data_ptr) {
bool DumpTensorStatsIfNeeded(const std::string &dump_path, const T &tensor, char *data_ptr, const std::string &io,
uint32_t slot, const ShapeVector &shape, TypeId type) {
// dump_path: dump_dir/op_type.op_name.task_id.stream_id.timestamp // dump_path: dump_dir/op_type.op_name.task_id.stream_id.timestamp
if (!DumpJsonParser::GetInstance().IsStatisticDump()) { if (!DumpJsonParser::GetInstance().IsStatisticDump()) {
return true; return true;
} }
std::string dump_path = dump_tensor_info.dump_file_path;
size_t pos = dump_path.rfind("/"); size_t pos = dump_path.rfind("/");
std::string file_name = dump_path.substr(pos + 1); std::string file_name = dump_path.substr(pos + 1);
size_t first_dot = file_name.find("."); size_t first_dot = file_name.find(".");
@ -697,15 +755,17 @@ bool DumpTensorStatsIfNeeded(const std::string &dump_path, const T &tensor, char
std::string task_id = file_name.substr(second_dot + 1, third_dot - second_dot - 1); std::string task_id = file_name.substr(second_dot + 1, third_dot - second_dot - 1);
std::string stream_id = file_name.substr(third_dot + 1, fourth_dot - third_dot - 1); std::string stream_id = file_name.substr(third_dot + 1, fourth_dot - third_dot - 1);
std::string timestamp = file_name.substr(fourth_dot + 1); std::string timestamp = file_name.substr(fourth_dot + 1);
TensorStatDump stat_dump(op_type, op_name, task_id, stream_id, timestamp, io, slot, slot); TensorStatDump stat_dump(op_type, op_name, task_id, stream_id, timestamp, dump_tensor_info.in_out_str,
dump_tensor_info.slot, dump_tensor_info.slot);
std::shared_ptr<TensorData> data = std::make_shared<TensorData>(); std::shared_ptr<TensorData> data = std::make_shared<TensorData>();
if (type <= TypeId::kNumberTypeBegin || type >= TypeId::kNumberTypeComplex64) { if (dump_tensor_info.data_type <= TypeId::kNumberTypeBegin ||
dump_tensor_info.data_type >= TypeId::kNumberTypeComplex64) {
MS_LOG(ERROR) << "Data type of operator " << file_name << " is not supported by statistic dump"; MS_LOG(ERROR) << "Data type of operator " << file_name << " is not supported by statistic dump";
return false; return false;
} }
data->SetType(type); data->SetType(dump_tensor_info.data_type);
data->SetByteSize((size_t)tensor.size()); data->SetByteSize(dump_tensor_info.data_size);
data->SetShape(shape); data->SetShape(dump_tensor_info.host_shape);
data->SetDataPtr(data_ptr); data->SetDataPtr(data_ptr);
return stat_dump.DumpTensorStatsToFile(dump_path.substr(0, pos), data); return stat_dump.DumpTensorStatsToFile(dump_path.substr(0, pos), data);
} }
@ -717,45 +777,19 @@ bool DumpTensorStatsIfNeeded(const std::string &dump_path, const T &tensor, char
* Description: It serves for A+M dump. Parse each attributes in Dumpdata proto object from device format to mindspore * Description: It serves for A+M dump. Parse each attributes in Dumpdata proto object from device format to mindspore
* supported format and save tensor data or statistic as configured. * supported format and save tensor data or statistic as configured.
*/ */
template <typename T> bool E2eDump::ConvertFormatForTensorAndDump(const dump_data_t &dump_tensor_info) {
bool E2eDump::ConvertFormatForTensorAndDump(std::string dump_path, const T &tensor, char *data_ptr,
const std::string &io, uint32_t slot) {
// dump_path: dump_dir/op_type.op_name.task_id.stream_id.timestamp // dump_path: dump_dir/op_type.op_name.task_id.stream_id.timestamp
std::ostringstream dump_path_ss; std::ostringstream dump_path_ss;
dump_path_ss << dump_path << "." << io << "." << slot << "."; dump_path_ss << dump_tensor_info.dump_file_path << "." << dump_tensor_info.in_out_str << "." << dump_tensor_info.slot
<< ".";
std::string dump_path_slot = dump_path_ss.str(); std::string dump_path_slot = dump_path_ss.str();
// get format
auto iter_fmt = kFormatToStringMap.find(tensor.format());
if (iter_fmt == kFormatToStringMap.end()) {
MS_LOG(INFO) << "Unsupported tensor format for tensor " << dump_path << ": unknown(" << tensor.format() << ")";
return false;
}
std::string device_format = iter_fmt->second;
// get data type
auto iter_dtype = kDataTypetoMSTypeMap.find(tensor.data_type());
if (iter_dtype == kDataTypetoMSTypeMap.end()) {
MS_LOG(INFO) << "Unsupported data type for tensor " << dump_path << ": unknown(" << tensor.data_type() << ")";
return false;
}
auto src_type = iter_dtype->second;
// get host shape
std::vector<size_t> device_shape;
(void)std::copy(tensor.shape().dim().begin(), tensor.shape().dim().end(), std::back_inserter(device_shape));
ShapeVector shape_d;
(void)std::transform(device_shape.begin(), device_shape.end(), std::back_inserter(shape_d), SizeToLong);
std::vector<size_t> host_shape;
(void)std::copy(tensor.original_shape().dim().begin(), tensor.original_shape().dim().end(),
std::back_inserter(host_shape));
ShapeVector shape_to;
(void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape_to), SizeToLong);
size_t data_size = (size_t)tensor.size();
bool trans_success = false; bool trans_success = false;
auto trans_buf = std::vector<uint8_t>(data_size); auto trans_buf = std::vector<uint8_t>(dump_tensor_info.data_size);
// convert format to host format. It can be either NCHW or ND (non 4-dimemsions). // convert format to host format. It can be either NCHW or ND (non 4-dimemsions).
const uint8_t kNumFourDim = 4; const uint8_t kNumFourDim = 4;
std::string host_format; std::string host_format;
if (host_shape.size() == kNumFourDim) { std::string device_format = dump_tensor_info.format;
if (dump_tensor_info.host_shape.size() == kNumFourDim) {
host_format = kOpFormat_NCHW; host_format = kOpFormat_NCHW;
} else { } else {
host_format = kOpFormat_ND; host_format = kOpFormat_ND;
@ -766,8 +800,14 @@ bool E2eDump::ConvertFormatForTensorAndDump(std::string dump_path, const T &tens
MS_LOG(INFO) << "Do not support convert from format " << device_format << " to " << host_format << " for tensor " MS_LOG(INFO) << "Do not support convert from format " << device_format << " to " << host_format << " for tensor "
<< dump_path_slot; << dump_path_slot;
} else { } else {
const trans::FormatArgs format_args{data_ptr, data_size, host_format, device_format, shape_to, shape_d, src_type}; const trans::FormatArgs format_args{dump_tensor_info.data_ptr,
auto group = tensor.sub_format() > 1 ? tensor.sub_format() : 1; dump_tensor_info.data_size,
host_format,
device_format,
dump_tensor_info.host_shape,
dump_tensor_info.device_shape,
dump_tensor_info.data_type};
auto group = dump_tensor_info.sub_format > 1 ? dump_tensor_info.sub_format : 1;
trans_success = trans::TransFormatFromDeviceToHost(format_args, trans_buf.data(), group); trans_success = trans::TransFormatFromDeviceToHost(format_args, trans_buf.data(), group);
if (!trans_success) { if (!trans_success) {
MS_LOG(ERROR) << "Trans format failed."; MS_LOG(ERROR) << "Trans format failed.";
@ -777,19 +817,21 @@ bool E2eDump::ConvertFormatForTensorAndDump(std::string dump_path, const T &tens
// dump tensor data into npy file // dump tensor data into npy file
bool dump_success = true; bool dump_success = true;
if (trans_success) { if (trans_success) {
dump_success = DumpTensorStatsIfNeeded(dump_path, tensor, reinterpret_cast<char *>(trans_buf.data()), io, slot, dump_success = DumpTensorStatsIfNeeded(dump_tensor_info, reinterpret_cast<char *>(trans_buf.data()));
shape_to, src_type);
if (DumpJsonParser::GetInstance().IsTensorDump()) { if (DumpJsonParser::GetInstance().IsTensorDump()) {
dump_path_slot += host_format; dump_path_slot += host_format;
dump_success = dump_success = DumpJsonParser::DumpToFile(dump_path_slot, trans_buf.data(), dump_tensor_info.data_size,
DumpJsonParser::DumpToFile(dump_path_slot, trans_buf.data(), data_size, shape_to, src_type) && dump_success; dump_tensor_info.host_shape, dump_tensor_info.data_type) &&
dump_success;
} }
} else { } else {
dump_success = DumpTensorStatsIfNeeded(dump_path, tensor, data_ptr, io, slot, shape_to, src_type); dump_success = DumpTensorStatsIfNeeded(dump_tensor_info, dump_tensor_info.data_ptr);
if (DumpJsonParser::GetInstance().IsTensorDump()) { if (DumpJsonParser::GetInstance().IsTensorDump()) {
dump_path_slot += device_format; dump_path_slot += device_format;
dump_success = dump_success = DumpJsonParser::DumpToFile(dump_path_slot, dump_tensor_info.data_ptr, dump_tensor_info.data_size,
DumpJsonParser::DumpToFile(dump_path_slot, data_ptr, data_size, shape_to, src_type) && dump_success; dump_tensor_info.host_shape, dump_tensor_info.data_type) &&
dump_success;
} }
} }
return dump_success; return dump_success;

View File

@ -20,6 +20,7 @@
#include <dirent.h> #include <dirent.h>
#include <map> #include <map>
#include <string> #include <string>
#include <vector>
#include "backend/common/session/kernel_graph.h" #include "backend/common/session/kernel_graph.h"
#include "runtime/device/device_address.h" #include "runtime/device/device_address.h"
@ -34,6 +35,19 @@ using mindspore::kernel::KernelLaunchInfo;
class Debugger; class Debugger;
#endif #endif
namespace mindspore { namespace mindspore {
struct dump_data_t {
std::string dump_file_path;
char *data_ptr;
mindspore::TypeId data_type;
std::string format;
ShapeVector device_shape;
ShapeVector host_shape;
size_t data_size;
int32_t sub_format;
std::string in_out_str;
uint32_t slot;
};
class E2eDump { class E2eDump {
public: public:
E2eDump() = default; E2eDump() = default;
@ -97,9 +111,10 @@ class E2eDump {
#ifdef ENABLE_D #ifdef ENABLE_D
static nlohmann::json ParseOverflowInfo(char *data_ptr); static nlohmann::json ParseOverflowInfo(char *data_ptr);
template <typename T> static bool ConvertFormatForTensorAndDump(const dump_data_t &dump_tensor_info);
static bool ConvertFormatForTensorAndDump(std::string dump_path, const T &tensor, char *data_ptr,
const std::string &io, uint32_t slot); static void ConvertFormatForTensors(const std::vector<dump_data_t> &dump_tensor_vec, uint32_t start_idx,
uint32_t end_idx);
#endif #endif
inline static unsigned int starting_graph_id = INT32_MAX; inline static unsigned int starting_graph_id = INT32_MAX;

View File

@ -52,28 +52,35 @@ bool CsvWriter::OpenFile(const std::string &path, const std::string &header) {
} }
// try to open file // try to open file
std::string file_path_value = file_path.value(); std::string file_path_value = file_path.value();
bool first_time_opening = file_path_str_ != path; {
ChangeFileMode(file_path_value, S_IWUSR); std::lock_guard<std::mutex> lock(dump_csv_lock_);
if (first_time_opening) { if (file_.is_open()) {
// remove any possible output from previous runs return true;
file_.open(file_path_value, std::ios::out | std::ios::trunc | std::ios::binary); }
} else { bool first_time_opening = file_path_str_ != path;
file_.open(file_path_value, std::ios::out | std::ios::app | std::ios::binary); ChangeFileMode(file_path_value, S_IWUSR);
if (first_time_opening) {
// remove any possible output from previous runs
file_.open(file_path_value, std::ios::out | std::ios::trunc | std::ios::binary);
} else {
file_.open(file_path_value, std::ios::out | std::ios::app | std::ios::binary);
}
if (!file_.is_open()) {
MS_LOG(WARNING) << "Open file " << file_path_value << " failed." << ErrnoToString(errno);
return false;
}
if (first_time_opening) {
file_ << header;
(void)file_.flush();
file_path_str_ = path;
}
MS_LOG(INFO) << "Opened file: " << file_path_value;
} }
if (!file_.is_open()) {
MS_LOG(WARNING) << "Open file " << file_path_value << " failed." << ErrnoToString(errno);
return false;
}
if (first_time_opening) {
file_ << header;
(void)file_.flush();
file_path_str_ = path;
}
MS_LOG(INFO) << "Opened file: " << file_path_value;
return true; return true;
} }
void CsvWriter::CloseFile() noexcept { void CsvWriter::CloseFile() noexcept {
std::lock_guard<std::mutex> lock(dump_csv_lock_);
if (file_.is_open()) { if (file_.is_open()) {
file_.close(); file_.close();
ChangeFileMode(file_path_str_, S_IRUSR); ChangeFileMode(file_path_str_, S_IRUSR);
@ -182,6 +189,7 @@ bool TensorStatDump::DumpTensorStatsToFile(const std::string &dump_path, const s
} }
shape << ")\""; shape << ")\"";
CsvWriter &csv = CsvWriter::GetInstance(); CsvWriter &csv = CsvWriter::GetInstance();
csv.Lock();
csv.WriteToCsv(op_type_); csv.WriteToCsv(op_type_);
csv.WriteToCsv(op_name_); csv.WriteToCsv(op_name_);
csv.WriteToCsv(task_id_); csv.WriteToCsv(task_id_);
@ -208,6 +216,7 @@ bool TensorStatDump::DumpTensorStatsToFile(const std::string &dump_path, const s
csv.WriteToCsv(stat.neg_inf_count); csv.WriteToCsv(stat.neg_inf_count);
csv.WriteToCsv(stat.pos_inf_count); csv.WriteToCsv(stat.pos_inf_count);
csv.WriteToCsv(stat.zero_count, true); csv.WriteToCsv(stat.zero_count, true);
csv.Unlock();
return true; return true;
} }
} // namespace mindspore } // namespace mindspore

View File

@ -20,6 +20,7 @@
#include <memory> #include <memory>
#include <string> #include <string>
#include <fstream> #include <fstream>
#include <mutex>
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
@ -40,12 +41,15 @@ class CsvWriter {
void CloseFile() noexcept; void CloseFile() noexcept;
template <typename T> template <typename T>
void WriteToCsv(const T &val, bool end_line = false); void WriteToCsv(const T &val, bool end_line = false);
void Lock() { dump_csv_lock_.lock(); }
void Unlock() { dump_csv_lock_.unlock(); }
private: private:
const std::string kSeparator = ","; const std::string kSeparator = ",";
const std::string kEndLine = "\n"; const std::string kEndLine = "\n";
std::ofstream file_; std::ofstream file_;
std::string file_path_str_ = ""; std::string file_path_str_ = "";
std::mutex dump_csv_lock_;
}; };
class TensorStatDump { class TensorStatDump {

View File

@ -412,18 +412,18 @@ def check_statistic_dump(dump_file_path):
with open(real_path) as f: with open(real_path) as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
stats = list(reader) stats = list(reader)
input1 = stats[0] num_tensors = len(stats)
assert input1['IO'] == 'input' assert num_tensors == 3
assert input1['Min Value'] == '1' for tensor in stats:
assert input1['Max Value'] == '6' if (tensor['IO'] == 'input' and tensor['Slot'] == 0):
input2 = stats[1] assert tensor['Min Value'] == '1'
assert input2['IO'] == 'input' assert tensor['Max Value'] == '6'
assert input2['Min Value'] == '7' elif (tensor['IO'] == 'input' and tensor['Slot'] == 1):
assert input2['Max Value'] == '12' assert tensor['Min Value'] == '7'
output = stats[2] assert tensor['Max Value'] == '12'
assert output['IO'] == 'output' elif (tensor['IO'] == 'output' and tensor['Slot'] == 0):
assert output['Min Value'] == '8' assert tensor['Min Value'] == '8'
assert output['Max Value'] == '18' assert tensor['Max Value'] == '18'
def check_data_dump(dump_file_path): def check_data_dump(dump_file_path):
output_name = "Add.Add-op*.output.0.*.npy" output_name = "Add.Add-op*.output.0.*.npy"