forked from mindspore-Ecosystem/mindspore
!12567 optimize rdr mainly to reduce the redundant code
From: @luopengting Reviewed-by: Signed-off-by:
This commit is contained in:
commit
e5aedcca47
|
@ -104,6 +104,8 @@ void EnvConfigParser::ParseRdrSetting(const nlohmann::json &content) {
|
|||
return;
|
||||
}
|
||||
|
||||
has_rdr_setting_ = true;
|
||||
|
||||
auto rdr_enable = CheckJsonKeyExist(*rdr_setting, kRdrSettings, kEnable);
|
||||
if (rdr_enable.has_value()) {
|
||||
ParseRdrEnable(**rdr_enable);
|
||||
|
|
|
@ -27,11 +27,14 @@ class EnvConfigParser {
|
|||
public:
|
||||
static EnvConfigParser &GetInstance() {
|
||||
static EnvConfigParser instance;
|
||||
instance.Parse();
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Parse();
|
||||
std::string config_path() const { return config_file_; }
|
||||
|
||||
bool has_rdr_setting() const { return has_rdr_setting_; }
|
||||
bool rdr_enabled() const { return rdr_enabled_; }
|
||||
std::string rdr_path() const { return rdr_path_; }
|
||||
|
||||
|
@ -42,7 +45,9 @@ class EnvConfigParser {
|
|||
std::mutex lock_;
|
||||
std::string config_file_{""};
|
||||
bool already_parsed_{false};
|
||||
|
||||
bool rdr_enabled_{false};
|
||||
bool has_rdr_setting_{false};
|
||||
std::string rdr_path_{"./rdr/"};
|
||||
|
||||
std::string GetIfstreamString(const std::ifstream &ifstream);
|
||||
|
|
|
@ -39,13 +39,13 @@ void BaseRecorder::SetFilename(const std::string &filename) {
|
|||
|
||||
std::optional<std::string> BaseRecorder::GetFileRealPath(const std::string &suffix) {
|
||||
if (filename_.empty()) {
|
||||
filename_ = module_ + "_" + tag_;
|
||||
filename_ = module_ + delimiter_ + tag_;
|
||||
if (!suffix.empty()) {
|
||||
filename_ += "_" + suffix;
|
||||
filename_ += delimiter_ + suffix;
|
||||
}
|
||||
filename_ += "_" + timestamp_;
|
||||
filename_ += delimiter_ + timestamp_;
|
||||
} else if (!suffix.empty()) {
|
||||
filename_ += "_" + suffix;
|
||||
filename_ += delimiter_ + suffix;
|
||||
}
|
||||
std::string file_path = directory_ + filename_;
|
||||
auto realpath = Common::GetRealPath(file_path);
|
||||
|
|
|
@ -65,6 +65,7 @@ class BaseRecorder {
|
|||
|
||||
void SetDirectory(const std::string &directory);
|
||||
void SetFilename(const std::string &filename);
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
virtual void Export() {}
|
||||
|
||||
protected:
|
||||
|
@ -73,6 +74,7 @@ class BaseRecorder {
|
|||
std::string directory_;
|
||||
std::string filename_;
|
||||
std::string timestamp_; // year,month,day,hour,minute,second
|
||||
std::string delimiter_{"."};
|
||||
};
|
||||
using BaseRecorderPtr = std::shared_ptr<BaseRecorder>;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,11 +46,11 @@ bool DumpGraphExeOrder(const std::string &filename, const std::vector<CNodePtr>
|
|||
} // namespace
|
||||
|
||||
void GraphExecOrderRecorder::Export() {
|
||||
auto realpath = GetFileRealPath();
|
||||
auto realpath = GetFileRealPath(std::to_string(graph_id_));
|
||||
if (!realpath.has_value()) {
|
||||
return;
|
||||
}
|
||||
std::string real_file_path = realpath.value() + std::to_string(graph_id_);
|
||||
std::string real_file_path = realpath.value() + ".txt";
|
||||
DumpGraphExeOrder(real_file_path, exec_order_);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -30,7 +30,6 @@ class GraphExecOrderRecorder : public BaseRecorder {
|
|||
GraphExecOrderRecorder(const std::string &module, const std::string &tag,
|
||||
const std::vector<CNodePtr> &final_exec_order, int graph_id)
|
||||
: BaseRecorder(module, tag), exec_order_(final_exec_order), graph_id_(graph_id) {}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
void SetExecOrder(const std::vector<CNodePtr> &final_exec_order) { exec_order_ = final_exec_order; }
|
||||
virtual void Export();
|
||||
|
||||
|
|
|
@ -65,6 +65,7 @@ void GraphRecorder::Export() {
|
|||
}
|
||||
std::string suffix = graph_id >= 0 ? std::to_string(graph_id) : "";
|
||||
auto tmp_realpath = GetFileRealPath(suffix);
|
||||
|
||||
if (!tmp_realpath.has_value()) {
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -31,7 +31,6 @@ class GraphRecorder : public BaseRecorder {
|
|||
const std::string &file_type)
|
||||
: BaseRecorder(module, tag), func_graph_(graph), graph_type_(file_type) {}
|
||||
~GraphRecorder() {}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
void SetGraphType(const std::string &file_type) { graph_type_ = file_type; }
|
||||
void SetFuncGraph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; }
|
||||
void SetDumpFlag(bool full_name) { full_name_ = full_name; }
|
||||
|
|
|
@ -21,7 +21,35 @@
|
|||
#include "mindspore/core/ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
void RecorderManager::UpdateRdrEnable() {
|
||||
static bool updated = false;
|
||||
if (updated) {
|
||||
return;
|
||||
}
|
||||
auto &config_parser = mindspore::EnvConfigParser::GetInstance();
|
||||
rdr_enable_ = config_parser.rdr_enabled();
|
||||
if (config_parser.has_rdr_setting()) {
|
||||
#ifdef __linux__
|
||||
if (!rdr_enable_) {
|
||||
MS_LOG(WARNING) << "Please set the 'enable' as true using 'rdr' setting in file '" << config_parser.config_path()
|
||||
<< "' if you want to use RDR.";
|
||||
}
|
||||
#else
|
||||
if (rdr_enable_) {
|
||||
MS_LOG(WARNING) << "The RDR only supports linux os currently.";
|
||||
}
|
||||
rdr_enable_ = false;
|
||||
#endif
|
||||
}
|
||||
updated = true;
|
||||
}
|
||||
|
||||
bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
|
||||
if (!rdr_enable_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (recorder == nullptr) {
|
||||
MS_LOG(ERROR) << "register recorder module with nullptr.";
|
||||
return false;
|
||||
|
@ -33,10 +61,7 @@ bool RecorderManager::RecordObject(const BaseRecorderPtr &recorder) {
|
|||
}
|
||||
|
||||
void RecorderManager::TriggerAll() {
|
||||
auto &config_parser_ptr = mindspore::EnvConfigParser::GetInstance();
|
||||
config_parser_ptr.Parse();
|
||||
if (!config_parser_ptr.rdr_enabled()) {
|
||||
MS_LOG(INFO) << "RDR is not enable.";
|
||||
if (!rdr_enable_) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -31,9 +31,12 @@ class RecorderManager {
|
|||
public:
|
||||
static RecorderManager &Instance() {
|
||||
static RecorderManager manager;
|
||||
manager.UpdateRdrEnable();
|
||||
return manager;
|
||||
}
|
||||
|
||||
void UpdateRdrEnable();
|
||||
bool RdrEnable() const { return rdr_enable_; }
|
||||
bool RecordObject(const BaseRecorderPtr &recorder);
|
||||
void TriggerAll();
|
||||
void ClearAll();
|
||||
|
@ -42,6 +45,8 @@ class RecorderManager {
|
|||
RecorderManager() {}
|
||||
~RecorderManager() {}
|
||||
|
||||
bool rdr_enable_{false};
|
||||
|
||||
mutable std::mutex mtx_;
|
||||
// module, BaserRecorderPtrList
|
||||
std::unordered_map<std::string, BaseRecorderPtrList> recorder_container_;
|
||||
|
|
|
@ -65,6 +65,9 @@ namespace RDR {
|
|||
#ifdef ENABLE_D
|
||||
bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
|
||||
const std::vector<TaskDebugInfoPtr> &task_debug_info_list, int graph_id) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
TaskDebugInfoRecorderPtr task_debug_info_recorder =
|
||||
std::make_shared<TaskDebugInfoRecorder>(submodule_name, tag, task_debug_info_list, graph_id);
|
||||
|
@ -73,9 +76,11 @@ bool RecordTaskDebugInfo(SubModuleId module, const std::string &tag,
|
|||
}
|
||||
#endif // ENABLE_D
|
||||
|
||||
#ifdef __linux__
|
||||
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
|
||||
const std::string &file_type) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
GraphRecorderPtr graph_recorder = std::make_shared<GraphRecorder>(submodule_name, tag, graph, file_type);
|
||||
graph_recorder->SetDumpFlag(full_name);
|
||||
|
@ -85,6 +90,9 @@ bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const Func
|
|||
|
||||
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
|
||||
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
GraphExecOrderRecorderPtr graph_exec_order_recorder =
|
||||
std::make_shared<GraphExecOrderRecorder>(submodule_name, tag, final_exec_order, graph_id);
|
||||
|
@ -93,6 +101,9 @@ bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
|
|||
}
|
||||
|
||||
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
StringRecorderPtr string_recorder = std::make_shared<StringRecorder>(submodule_name, tag, data, filename);
|
||||
string_recorder->SetFilename(filename);
|
||||
|
@ -102,6 +113,9 @@ bool RecordString(SubModuleId module, const std::string &tag, const std::string
|
|||
|
||||
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
|
||||
const std::vector<CNodePtr> &exec_order) {
|
||||
if (!mindspore::RecorderManager::Instance().RdrEnable()) {
|
||||
return false;
|
||||
}
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
StreamExecOrderRecorderPtr stream_exec_order_recorder =
|
||||
std::make_shared<StreamExecOrderRecorder>(submodule_name, tag, graph_id, exec_order);
|
||||
|
@ -113,68 +127,5 @@ void TriggerAll() { mindspore::RecorderManager::Instance().TriggerAll(); }
|
|||
|
||||
void ClearAll() { mindspore::RecorderManager::Instance().ClearAll(); }
|
||||
|
||||
#else
|
||||
bool RecordAnfGraph(const SubModuleId module, const std::string &tag, const FuncGraphPtr &graph, bool full_name,
|
||||
const std::string &file_type) {
|
||||
static bool already_printed = false;
|
||||
std::string submodule_name = std::string(GetSubModuleName(module));
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os " << submodule_name;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecordGraphExecOrder(const SubModuleId module, const std::string &tag,
|
||||
const std::vector<CNodePtr> &final_exec_order, int graph_id) {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecordString(SubModuleId module, const std::string &tag, const std::string &data, const std::string &filename) {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
return false;
|
||||
}
|
||||
|
||||
bool RecordStreamExecOrder(const SubModuleId module, const std::string &tag, const int &graph_id,
|
||||
const std::vector<CNodePtr> &exec_order) {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return false;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
return false;
|
||||
}
|
||||
|
||||
void TriggerAll() {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
}
|
||||
|
||||
void ClearAll() {
|
||||
static bool already_printed = false;
|
||||
if (already_printed) {
|
||||
return;
|
||||
}
|
||||
already_printed = true;
|
||||
MS_LOG(WARNING) << "The RDR presently only support linux os.";
|
||||
}
|
||||
#endif // __linux__
|
||||
} // namespace RDR
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -98,7 +98,6 @@ class StreamExecOrderRecorder : public BaseRecorder {
|
|||
exec_order_.push_back(std::move(exec_node_ptr));
|
||||
}
|
||||
}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
virtual void Export();
|
||||
|
||||
private:
|
||||
|
|
|
@ -26,14 +26,14 @@ class StringRecorder : public BaseRecorder {
|
|||
StringRecorder() : BaseRecorder() {}
|
||||
StringRecorder(const std::string &module, const std::string &tag, const std::string &data,
|
||||
const std::string &filename)
|
||||
: BaseRecorder(module, tag), data_(data), filename_(filename) {}
|
||||
: BaseRecorder(module, tag), data_(data) {
|
||||
SetFilename(filename);
|
||||
}
|
||||
~StringRecorder() {}
|
||||
void SetModule(const std::string &module) { module_ = module; }
|
||||
virtual void Export();
|
||||
|
||||
private:
|
||||
std::string data_;
|
||||
std::string filename_;
|
||||
};
|
||||
using StringRecorderPtr = std::shared_ptr<StringRecorder>;
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -258,8 +258,8 @@ class _Context:
|
|||
def set_env_config_path(self, env_config_path):
|
||||
"""Check and set env_config_path."""
|
||||
if not self._context_handle.enable_dump_ir():
|
||||
raise ValueError("The 'env_config_path' is not supported, please turn on ENABLE_DUMP_IR "
|
||||
"and recompile source to enable it.")
|
||||
raise ValueError("The 'env_config_path' is not supported, please enable ENABLE_DUMP_IR "
|
||||
"with '-D on' and recompile source.")
|
||||
env_config_path = os.path.realpath(env_config_path)
|
||||
if not os.path.isfile(env_config_path):
|
||||
raise ValueError("The %r set by 'env_config_path' should be an existing json file." % env_config_path)
|
||||
|
|
Loading…
Reference in New Issue