!12567 optimize rdr mainly to reduce the redundant code

From: @luopengting
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-01 16:29:30 +08:00 committed by Gitee
commit e5aedcca47
14 changed files with 70 additions and 82 deletions

View File

@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
return;
}
has_rdr_setting_ = true;
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
if (rdr_enable.has_value()) {
ParseRdrEnable(**rdr_enable);

View File

@ -27,11 +27,14 @@ class EnvConfigParser {
public:
static EnvConfigParser &GetInstance() {
static EnvConfigParser instance;
instance.Parse();
return instance;
}
void Parse();
std::string config_path() const { return config_file_; }
bool has_rdr_setting() const { return has_rdr_setting_; }
bool rdr_enabled() const { return rdr_enabled_; }
std::string rdr_path() const { return rdr_path_; }
@ -42,7 +45,9 @@ class EnvConfigParser {
std::mutex lock_;
std::string config_file_{""};
bool already_parsed_{false};
bool rdr_enabled_{false};
bool has_rdr_setting_{false};
std::string rdr_path_{"./rdr/"};
std::string GetIfstreamString(const std::ifstream &ifstream);

View File

@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
if (filename_.empty()) {
filename_ = module_ + "_" + tag_;
filename_ = module_ + delimiter_ + tag_;
if (!suffix.empty()) {
filename_ += "_" + suffix;
filename_ += delimiter_ + suffix;
}
filename_ += "_" + timestamp_;
filename_ += delimiter_ + timestamp_;
} else if (!suffix.empty()) {
filename_ += "_" + suffix;
filename_ += delimiter_ + suffix;
}
std::string file_path = directory_ + filename_;
auto realpath = Common::GetRealPath(file_path);

View File

@ -65,6 +65,7 @@ class BaseRecorder {
void SetDirectory(const std::string &directory);
void SetFilename(const std::string &filename);
void SetModule(const std::string &module) { module_ = module; }
virtual void Export() {}
protected:
@ -73,6 +74,7 @@ class BaseRecorder {
std::string directory_;
std::string filename_;
std::string timestamp_; // year,month,day,hour,minute,second
std::string delimiter_{"."};
};
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
} // namespace mindspore

View File

@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
} // namespace
void GraphExecOrderRecorder::Export() {
auto realpath = GetFileRealPath();
auto realpath = GetFileRealPath(std::to_string(graph_id_));
if (!realpath.has_value()) {
return;
}
std::string real_file_path = realpath.value() + std::to_string(graph_id_);
std::string real_file_path = realpath.value() + ".txt";
DumpGraphExeOrder(real_file_path, exec_order_);
}
} // namespace mindspore

View File

@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
GraphExecOrderRecorder(const std::string &module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id)
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
void SetModule(const std::string &module) { module_ = module; }
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
virtual void Export();

View File

@ -65,6 +65,7 @@ void GraphRecorder::Export() {
}
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
auto tmp_realpath = GetFileRealPath(suffix);
if (!tmp_realpath.has_value()) {
return;
}

View File

@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
const std::string &file_type)
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
~GraphRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
void SetDumpFlag(bool full_name) { full_name_ = full_name; }

View File

@ -21,7 +21,35 @@
#include "mindspore/core/ir/func_graph.h"
namespace mindspore {
void RecorderManager::UpdateRdrEnable() {
static bool updated = false;
if (updated) {
return;
}
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
rdr_enable_ = config_parser.rdr_enabled();
if (config_parser.has_rdr_setting()) {
#ifdef __linux__
if (!rdr_enable_) {
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
<< "' if you want to use RDR.";
}
#else
if (rdr_enable_) {
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
}
rdr_enable_ = false;
#endif
}
updated = true;
}
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
if (!rdr_enable_) {
return false;
}
if (recorder == nullptr) {
MS_LOG(ERROR) << "register recorder module with nullptr.";
return false;
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
}
void RecorderManager::TriggerAll() {
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance();
config_parser_ptr.Parse();
if (!config_parser_ptr.rdr_enabled()) {
MS_LOG(INFO) << "RDR is not enable.";
if (!rdr_enable_) {
return;
}

View File

@ -31,9 +31,12 @@ class RecorderManager {
public:
static RecorderManager &Instance() {
static RecorderManager manager;
manager.UpdateRdrEnable();
return manager;
}
void UpdateRdrEnable();
bool RdrEnable() const { return rdr_enable_; }
bool RecordObject(const BaseRecorderPtr &recorder);
void TriggerAll();
void ClearAll();
@ -42,6 +45,8 @@ class RecorderManager {
RecorderManager() {}
~RecorderManager() {}
bool rdr_enable_{false};
mutable std::mutex mtx_;
// module, BaserRecorderPtrList
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;

View File

@ -65,6 +65,9 @@ namespace RDR {
#ifdef ENABLE_D
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
TaskDebugInfoRecorderPtr task_debug_info_recorder =
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
}
#endif // ENABLE_D
#ifdef __linux__
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
graph_recorder->SetDumpFlag(full_name);
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
GraphExecOrderRecorderPtr graph_exec_order_recorder =
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
string_recorder->SetFilename(filename);
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
return false;
}
std::string submodule_name = std::string(GetSubModuleName(module));
StreamExecOrderRecorderPtr stream_exec_order_recorder =
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
#else
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
const std::string &file_type) {
static bool already_printed = false;
std::string submodule_name = std::string(GetSubModuleName(module));
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
return false;
}
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
const std::vector<CNodePtr> &exec_order) {
static bool already_printed = false;
if (already_printed) {
return false;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
return false;
}
void TriggerAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
void ClearAll() {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The RDR presently only support linux os.";
}
#endif // __linux__
} // namespace RDR
} // namespace mindspore

View File

@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
exec_order_.push_back(std::move(exec_node_ptr));
}
}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export();
private:

View File

@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
StringRecorder() : BaseRecorder() {}
StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
const std::string &filename)
: BaseRecorder(module, tag), data_(data), filename_(filename) {}
: BaseRecorder(module, tag), data_(data) {
SetFilename(filename);
}
~StringRecorder() {}
void SetModule(const std::string &module) { module_ = module; }
virtual void Export();
private:
std::string data_;
std::string filename_;
};
using StringRecorderPtr = std::shared_ptr<StringRecorder>;
} // namespace mindspore

View File

@ -258,8 +258,8 @@ class _Context:
def set_env_config_path(self, env_config_path):
"""Check and set env_config_path."""
if not self._context_handle.enable_dump_ir():
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR "
"and recompile source to enable it.")
raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
"with '-D on' and recompile source.")
env_config_path = os.path.realpath(env_config_path)
if not os.path.isfile(env_config_path):
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)