forked from mindspore-Ecosystem/mindspore
!39990 Use the "best_split" tuning result in kernel_meta
Merge pull request !39990 from DeshiChen/0801_split_tune_bin
This commit is contained in:
commit
cc14e844bd
|
@ -204,6 +204,9 @@ class CostModelSplitSchemer : public SplitSchemer {
|
||||||
MS_LOG(ERROR) << "Collect json desc failed.";
|
MS_LOG(ERROR) << "Collect json desc failed.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
// set the "node_name" for tracing split result.
|
||||||
|
std::string node_name = json_desc["op"];
|
||||||
|
func_graph_->set_attr(kAttrNodeName, MakeValue(node_name));
|
||||||
|
|
||||||
// call costmodel split function.
|
// call costmodel split function.
|
||||||
auto json_desc_str = json_desc.dump();
|
auto json_desc_str = json_desc.dump();
|
||||||
|
|
|
@ -375,6 +375,7 @@ class Splitter {
|
||||||
void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
|
void RebuildGraph(const std::vector<size_t> &cnodes_group_id) {
|
||||||
BindFuncGraph();
|
BindFuncGraph();
|
||||||
RecoverParameter();
|
RecoverParameter();
|
||||||
|
SetSplitNodeName(cnodes_group_id);
|
||||||
ConnectToMainGraph(cnodes_group_id);
|
ConnectToMainGraph(cnodes_group_id);
|
||||||
UpdateSubGraphInfo();
|
UpdateSubGraphInfo();
|
||||||
ResetInlinedNodesKernelInfo();
|
ResetInlinedNodesKernelInfo();
|
||||||
|
@ -443,6 +444,23 @@ class Splitter {
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void SetSplitNodeName(const std::vector<size_t> &cnodes_group_id) const {
|
||||||
|
auto old_func_graph = GetCNodeFuncGraph(old_subgraph_cnode_);
|
||||||
|
std::string ori_node_name;
|
||||||
|
if (old_func_graph->has_attr(kAttrNodeName)) {
|
||||||
|
ori_node_name = GetValue<std::string>(old_func_graph->get_attr(kAttrNodeName));
|
||||||
|
} else {
|
||||||
|
ori_node_name = GetValue<std::string>(old_func_graph->get_attr("graph_kernel"));
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < new_subgraph_cnodes_.size(); ++i) {
|
||||||
|
auto group_id = cnodes_group_id[i];
|
||||||
|
if (!split_schemer_->NeedInline(group_id)) {
|
||||||
|
std::string node_name = ori_node_name + "_" + std::to_string(group_id);
|
||||||
|
AnfUtils::SetNodeAttr(kAttrNodeName, MakeValue(node_name), new_subgraph_cnodes_[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Set the new sub_func_graph node as input of nodes original main graph.
|
// Set the new sub_func_graph node as input of nodes original main graph.
|
||||||
void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
|
void ConnectToMainGraph(const std::vector<size_t> &cnodes_group_id) {
|
||||||
// For single output kernel, the last area contains the original output node (return node),
|
// For single output kernel, the last area contains the original output node (return node),
|
||||||
|
|
|
@ -57,15 +57,15 @@ bool TuningSplitSchemer::ParseResult(const AnfNodePtrList &nodes, const nlohmann
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TuningSplitSchemer::Split(const FuncGraphPtr &func_graph) {
|
bool TuningSplitSchemer::Split(const FuncGraphPtr &func_graph) {
|
||||||
if (!func_graph->has_attr("kernel_name")) {
|
if (!func_graph->has_attr(kAttrNodeName)) {
|
||||||
MS_LOG(WARNING) << "The func_graph has not attr \"kernel_name\".";
|
MS_LOG(WARNING) << "The func_graph has not attr \"node_name\".";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::string kernel_name = GetValue<std::string>(func_graph->get_attr("kernel_name"));
|
std::string node_name = GetValue<std::string>(func_graph->get_attr(kAttrNodeName));
|
||||||
AnfNodePtrList nodes;
|
AnfNodePtrList nodes;
|
||||||
GkUtils::GetValidKernelNodes(func_graph, &nodes, nullptr, nullptr);
|
GkUtils::GetValidKernelNodes(func_graph, &nodes, nullptr, nullptr);
|
||||||
// the input json has postfix ".info", and the result file has postfix ".json"
|
// the input json has postfix ".info", and the result file has postfix ".json"
|
||||||
auto result_file = tuning_path_ + "/" + kernel_name + ".json";
|
auto result_file = tuning_path_ + "/" + node_name + ".json";
|
||||||
nlohmann::json tuning_result;
|
nlohmann::json tuning_result;
|
||||||
if (!ReadCache(result_file, &tuning_result)) {
|
if (!ReadCache(result_file, &tuning_result)) {
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -742,8 +742,10 @@ bool AkgKernelJsonGenerator::CollectJson(const AnfNodePtr &anf_node, nlohmann::j
|
||||||
|
|
||||||
// gen hash id with the above info.
|
// gen hash id with the above info.
|
||||||
size_t hash_id = GenHashId(kernel_json->dump());
|
size_t hash_id = GenHashId(kernel_json->dump());
|
||||||
kernel_name_ = op_name + "_";
|
kernel_name_ = op_name + "_" + std::to_string(hash_id);
|
||||||
(void)kernel_name_.append(std::to_string(hash_id));
|
if (dump_option_.gen_kernel_name_only) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
||||||
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||||
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||||
|
@ -847,12 +849,17 @@ bool AkgKernelJsonGenerator::CollectFusedJson(const std::vector<AnfNodePtr> &anf
|
||||||
auto fg = anf_nodes[0]->func_graph();
|
auto fg = anf_nodes[0]->func_graph();
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
GenKernelName(fg, hash_id, kernel_json);
|
GenKernelName(fg, hash_id, kernel_json);
|
||||||
|
if (dump_option_.gen_kernel_name_only) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
(*kernel_json)[kJsonKeyId] = 0; // unused key
|
||||||
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
(*kernel_json)[kJsonKeyOp] = kernel_name_;
|
||||||
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
(*kernel_json)[kJsonKeyPlatform] = "AKG";
|
||||||
(*kernel_json)[kJsonKeyComposite] = true;
|
(*kernel_json)[kJsonKeyComposite] = true;
|
||||||
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
|
(*kernel_json)[kJsonKeyCompositeGraph] = fg->ToString();
|
||||||
|
if (fg->has_attr(kAttrNodeName)) {
|
||||||
|
(*kernel_json)[kJsonKeyNodeName] = GetValue<std::string>(fg->get_attr(kAttrNodeName));
|
||||||
|
}
|
||||||
|
|
||||||
GetIOSize(*kernel_json, &input_size_list_, &output_size_list_);
|
GetIOSize(*kernel_json, &input_size_list_, &output_size_list_);
|
||||||
|
|
||||||
|
|
|
@ -70,6 +70,7 @@ constexpr auto kJsonKeySmCount = "sm_count";
|
||||||
constexpr auto kJsonKeySystem = "system";
|
constexpr auto kJsonKeySystem = "system";
|
||||||
constexpr auto kJsonKeyArch = "arch";
|
constexpr auto kJsonKeyArch = "arch";
|
||||||
constexpr auto kJsonKeyFeature = "feature";
|
constexpr auto kJsonKeyFeature = "feature";
|
||||||
|
constexpr auto kJsonKeyNodeName = "node_name";
|
||||||
|
|
||||||
// dump option
|
// dump option
|
||||||
struct DumpOption {
|
struct DumpOption {
|
||||||
|
@ -77,6 +78,7 @@ struct DumpOption {
|
||||||
bool save_ptr_address = false;
|
bool save_ptr_address = false;
|
||||||
bool extract_opinfo_from_anfnode = false;
|
bool extract_opinfo_from_anfnode = false;
|
||||||
bool get_target_info = false;
|
bool get_target_info = false;
|
||||||
|
bool gen_kernel_name_only = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TargetInfoSetter {
|
class TargetInfoSetter {
|
||||||
|
|
|
@ -25,12 +25,12 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||||
#include "kernel/akg/akg_kernel_json_generator.h"
|
|
||||||
#include "ir/anf.h"
|
#include "ir/anf.h"
|
||||||
#include "ir/func_graph.h"
|
#include "ir/func_graph.h"
|
||||||
#include "utils/anf_utils.h"
|
#include "utils/anf_utils.h"
|
||||||
#include "utils/file_utils.h"
|
#include "utils/file_utils.h"
|
||||||
#include "utils/log_adapter.h"
|
#include "utils/log_adapter.h"
|
||||||
|
#include "utils/system/env.h"
|
||||||
|
|
||||||
namespace mindspore::graphkernel {
|
namespace mindspore::graphkernel {
|
||||||
bool CompileSingleJson(const std::string &json_name) {
|
bool CompileSingleJson(const std::string &json_name) {
|
||||||
|
@ -43,7 +43,7 @@ bool CompileSingleJson(const std::string &json_name) {
|
||||||
py_cmd << "if not compilewithjsonname(\'" << json_name << "\', " << attrs << "):\n";
|
py_cmd << "if not compilewithjsonname(\'" << json_name << "\', " << attrs << "):\n";
|
||||||
py_cmd << " raise RuntimeError(\'Compile fail for json: " << json_name << "\')";
|
py_cmd << " raise RuntimeError(\'Compile fail for json: " << json_name << "\')";
|
||||||
std::string cmd = "unset LD_LIBRARY_PATH;python -c \"" + py_cmd.str() + "\"";
|
std::string cmd = "unset LD_LIBRARY_PATH;python -c \"" + py_cmd.str() + "\"";
|
||||||
auto ret = system(cmd.c_str());
|
auto ret = std::system(cmd.c_str());
|
||||||
if (!WIFEXITED(ret)) {
|
if (!WIFEXITED(ret)) {
|
||||||
MS_LOG(ERROR) << "Python process start fail! process content is as follows:\n" << cmd;
|
MS_LOG(ERROR) << "Python process start fail! process content is as follows:\n" << cmd;
|
||||||
return false;
|
return false;
|
||||||
|
@ -131,37 +131,15 @@ bool SaveJsonInfo(const std::string &json_name, const std::string &info) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void CheckObjFiles(const std::string &dir_path, const std::vector<std::string> &json_list) {
|
std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir, const DumpOption &option,
|
||||||
constexpr size_t try_times = 10;
|
std::map<AnfNodePtr, std::string> *node_kernel, std::vector<std::string> *kernel_names) {
|
||||||
constexpr size_t wait_us = 100000;
|
|
||||||
for (auto const json_name : json_list) {
|
|
||||||
auto file_name = dir_path + "/" + json_name + ".json";
|
|
||||||
bool exist = false;
|
|
||||||
for (size_t i = 0; i < try_times; ++i) {
|
|
||||||
std::ifstream f(file_name.c_str());
|
|
||||||
if (f.good()) {
|
|
||||||
exist = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
usleep(wait_us);
|
|
||||||
}
|
|
||||||
if (!exist) {
|
|
||||||
MS_LOG(EXCEPTION) << "akg file " << json_name << ".json not exist!";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir,
|
|
||||||
std::map<AnfNodePtr, std::string> *node_kernel, std::vector<std::string> *json_list) {
|
|
||||||
auto dir_path = FileUtils::CreateNotExistDirs(dir);
|
auto dir_path = FileUtils::CreateNotExistDirs(dir);
|
||||||
if (!dir_path.has_value()) {
|
if (!dir_path.has_value()) {
|
||||||
MS_LOG(ERROR) << "Failed to CreateNotExistDirs: " << dir;
|
MS_LOG(ERROR) << "Failed to CreateNotExistDirs: " << dir;
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
std::vector<std::string> kernel_names;
|
std::vector<std::string> unique_kernel_name;
|
||||||
for (const auto &node : nodes) {
|
for (const auto &node : nodes) {
|
||||||
graphkernel::DumpOption option;
|
|
||||||
option.get_target_info = true;
|
|
||||||
graphkernel::AkgKernelJsonGenerator akg_kernel_json_generator(option);
|
graphkernel::AkgKernelJsonGenerator akg_kernel_json_generator(option);
|
||||||
auto fg = GetCNodeFuncGraph(node);
|
auto fg = GetCNodeFuncGraph(node);
|
||||||
MS_EXCEPTION_IF_NULL(fg);
|
MS_EXCEPTION_IF_NULL(fg);
|
||||||
|
@ -177,39 +155,59 @@ std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir,
|
||||||
if (node_kernel != nullptr) {
|
if (node_kernel != nullptr) {
|
||||||
(*node_kernel)[node] = json_kernel_name;
|
(*node_kernel)[node] = json_kernel_name;
|
||||||
}
|
}
|
||||||
if (find(kernel_names.cbegin(), kernel_names.cend(), json_kernel_name) != kernel_names.cend()) {
|
if (find(unique_kernel_name.cbegin(), unique_kernel_name.cend(), json_kernel_name) != unique_kernel_name.cend()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
kernel_names.push_back(json_kernel_name);
|
unique_kernel_name.push_back(json_kernel_name);
|
||||||
if (!SaveJsonInfo(dir_path.value() + "/" + json_kernel_name, akg_kernel_json_generator.kernel_json_str())) {
|
if (!SaveJsonInfo(dir_path.value() + "/" + json_kernel_name, akg_kernel_json_generator.kernel_json_str())) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (json_list != nullptr) {
|
if (kernel_names != nullptr) {
|
||||||
*json_list = std::move(kernel_names);
|
*kernel_names = std::move(unique_kernel_name);
|
||||||
}
|
}
|
||||||
return dir_path.value();
|
return dir_path.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ExcludeCachedObj(const std::string &dir_path, std::vector<std::string> *kernel_names,
|
||||||
|
std::vector<std::string> *obj_files) {
|
||||||
|
auto fs = system::Env::GetFileSystem();
|
||||||
|
std::vector<std::string> new_kernel_names;
|
||||||
|
for (const auto &name : *kernel_names) {
|
||||||
|
auto tuned_op_cache = dir_path + "/Tuned_" + name + ".o";
|
||||||
|
if (fs->FileExist(tuned_op_cache)) {
|
||||||
|
MS_LOG(INFO) << "Reuse the object file " << tuned_op_cache;
|
||||||
|
(void)obj_files->emplace_back(std::move(tuned_op_cache));
|
||||||
|
} else {
|
||||||
|
new_kernel_names.push_back(name);
|
||||||
|
(void)obj_files->emplace_back(dir_path + "/" + name + ".o");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (new_kernel_names.size() < kernel_names->size()) {
|
||||||
|
*kernel_names = std::move(new_kernel_names);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool AkgKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) {
|
bool AkgKernelBuilder::CompileJsonsInAnfnodes(const AnfNodePtrList &node_list) {
|
||||||
std::map<AnfNodePtr, std::string> node_name;
|
std::map<AnfNodePtr, std::string> node_name;
|
||||||
std::vector<std::string> json_list;
|
std::vector<std::string> kernel_names;
|
||||||
auto dir_path = SaveNodesInfo(node_list, "./kernel_meta", &node_name, &json_list);
|
auto dir_path = SaveNodesInfo(node_list, "./kernel_meta", AkgKernelBuilder::json_option(), &node_name, &kernel_names);
|
||||||
if (dir_path.empty()) {
|
if (dir_path.empty()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto res = CompileJsonsInList(dir_path, json_list);
|
std::vector<std::string> obj_files;
|
||||||
|
ExcludeCachedObj(dir_path, &kernel_names, &obj_files);
|
||||||
|
auto res = CompileJsonsInList(dir_path, kernel_names);
|
||||||
if (res) {
|
if (res) {
|
||||||
for (const auto &iter : node_name) {
|
for (const auto &iter : node_name) {
|
||||||
AnfUtils::SetNodeAttr("kernel_name", MakeValue(iter.second + "_kernel"), iter.first);
|
AnfUtils::SetNodeAttr("kernel_name", MakeValue(iter.second + "_kernel"), iter.first);
|
||||||
}
|
}
|
||||||
std::ostringstream kernels_name;
|
std::ostringstream objs;
|
||||||
for (const auto &json_kernel_name : json_list) {
|
for (const auto &obj : obj_files) {
|
||||||
kernels_name << dir_path << "/" << json_kernel_name << ".o ";
|
objs << obj << " ";
|
||||||
}
|
}
|
||||||
CheckObjFiles(dir_path, json_list);
|
auto cmd = "g++ -fPIC -shared -o akgkernels.so " + objs.str();
|
||||||
auto cmd = "g++ -fPIC -shared -o akgkernels.so " + kernels_name.str();
|
if (std::system(cmd.c_str()) == 0) {
|
||||||
if (system(cmd.c_str()) == 0) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "utils/anf_utils.h"
|
#include "utils/anf_utils.h"
|
||||||
|
#include "kernel/akg/akg_kernel_json_generator.h"
|
||||||
|
|
||||||
namespace mindspore::graphkernel {
|
namespace mindspore::graphkernel {
|
||||||
constexpr size_t PROCESS_LIMIT = 8;
|
constexpr size_t PROCESS_LIMIT = 8;
|
||||||
|
@ -31,9 +32,15 @@ class AkgKernelBuilder {
|
||||||
~AkgKernelBuilder() = default;
|
~AkgKernelBuilder() = default;
|
||||||
|
|
||||||
bool CompileJsonsInAnfnodes(const AnfNodePtrList &node_list);
|
bool CompileJsonsInAnfnodes(const AnfNodePtrList &node_list);
|
||||||
|
|
||||||
|
static DumpOption json_option() {
|
||||||
|
DumpOption dump_json_option;
|
||||||
|
dump_json_option.get_target_info = true;
|
||||||
|
return dump_json_option;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir,
|
std::string SaveNodesInfo(const AnfNodePtrList &nodes, const std::string &dir, const DumpOption &option,
|
||||||
std::map<AnfNodePtr, std::string> *node_name, std::vector<std::string> *json_list);
|
std::map<AnfNodePtr, std::string> *node_name, std::vector<std::string> *kernel_names);
|
||||||
} // namespace mindspore::graphkernel
|
} // namespace mindspore::graphkernel
|
||||||
#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_AKG_BUILD_H_
|
#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_AKG_AKG_BUILD_H_
|
||||||
|
|
|
@ -16,9 +16,13 @@
|
||||||
#include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h"
|
#include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h"
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <cstdio>
|
||||||
|
#include "utils/system/env.h"
|
||||||
|
#include "utils/file_utils.h"
|
||||||
#include "utils/anf_utils.h"
|
#include "utils/anf_utils.h"
|
||||||
#include "common/graph_kernel/graph_kernel_flags.h"
|
#include "common/graph_kernel/graph_kernel_flags.h"
|
||||||
#include "common/graph_kernel/core/tuning_splitter.h"
|
#include "common/graph_kernel/core/tuning_splitter.h"
|
||||||
|
#include "common/graph_kernel/core/graph_kernel_utils.h"
|
||||||
#include "tools/graph_kernel/converter/akg/akg_build.h"
|
#include "tools/graph_kernel/converter/akg/akg_build.h"
|
||||||
|
|
||||||
namespace mindspore::graphkernel {
|
namespace mindspore::graphkernel {
|
||||||
|
@ -43,10 +47,10 @@ bool GraphKernelSplitterWithTuning::StartTuning(const std::string &dir_path) con
|
||||||
py_cmd << "import sys; sys.path.insert(0, get_akg_path())\n";
|
py_cmd << "import sys; sys.path.insert(0, get_akg_path())\n";
|
||||||
py_cmd << "from akg.ms import " << tune_interface << "\n";
|
py_cmd << "from akg.ms import " << tune_interface << "\n";
|
||||||
py_cmd << "if not " << tune_interface << "(\'" << dir_path << "\', " << attrs.str() << "):\n";
|
py_cmd << "if not " << tune_interface << "(\'" << dir_path << "\', " << attrs.str() << "):\n";
|
||||||
py_cmd << " raise RuntimeError(\'Tune fail for json: " << dir_path << "\')";
|
py_cmd << " raise RuntimeError(\'Tune fail. info path: " << dir_path << "\')";
|
||||||
std::string cmd = "unset LD_LIBRARY_PATH;python -c \"" + py_cmd.str() + "\"";
|
std::string cmd = "unset LD_LIBRARY_PATH;python -c \"" + py_cmd.str() + "\"";
|
||||||
MS_LOG(INFO) << "GraphKernel online tuning content: \n" << cmd;
|
MS_LOG(INFO) << "GraphKernel online tuning content: \n" << cmd;
|
||||||
auto ret = system(cmd.c_str());
|
auto ret = std::system(cmd.c_str());
|
||||||
if (!WIFEXITED(ret)) {
|
if (!WIFEXITED(ret)) {
|
||||||
MS_LOG(ERROR) << "Python process start fail! process content is as follows:\n" << cmd;
|
MS_LOG(ERROR) << "Python process start fail! process content is as follows:\n" << cmd;
|
||||||
return false;
|
return false;
|
||||||
|
@ -58,6 +62,40 @@ bool GraphKernelSplitterWithTuning::StartTuning(const std::string &dir_path) con
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RenameKernelFiles(const FuncGraphPtr &func_graph) {
|
||||||
|
auto kernel_meta = FileUtils::GetRealPath("./kernel_meta/");
|
||||||
|
if (!kernel_meta.has_value()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto fs = system::Env::GetFileSystem();
|
||||||
|
MS_EXCEPTION_IF_NULL(fs);
|
||||||
|
DumpOption option = AkgKernelBuilder::json_option();
|
||||||
|
option.gen_kernel_name_only = true;
|
||||||
|
|
||||||
|
auto todos = TopoSort(func_graph->get_return());
|
||||||
|
for (const auto &node : todos) {
|
||||||
|
if (!AnfUtils::IsGraphKernel(node)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto fg = GetCNodeFuncGraph(node);
|
||||||
|
if (!fg->has_attr(kAttrNodeName)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto node_name = GetValue<std::string>(fg->get_attr(kAttrNodeName));
|
||||||
|
auto kernel_obj = kernel_meta.value() + "/best_split_" + node_name + ".o";
|
||||||
|
if (fs->FileExist(kernel_obj)) {
|
||||||
|
AkgKernelJsonGenerator json_generator(option);
|
||||||
|
std::vector<AnfNodePtr> node_list, input_list, output_list;
|
||||||
|
GkUtils::GetValidKernelNodes(fg, &node_list, &input_list, &output_list);
|
||||||
|
(void)json_generator.CollectFusedJson(node_list, input_list, output_list);
|
||||||
|
auto new_kernel_obj = kernel_meta.value() + "/Tuned_" + json_generator.kernel_name() + ".o";
|
||||||
|
// only rename file, but not change the "node_name" attr.
|
||||||
|
MS_LOG(INFO) << "Rename " << kernel_obj << " to " << new_kernel_obj;
|
||||||
|
std::rename(kernel_obj.c_str(), new_kernel_obj.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool GraphKernelSplitterWithTuning::Run(const FuncGraphPtr &func_graph) {
|
bool GraphKernelSplitterWithTuning::Run(const FuncGraphPtr &func_graph) {
|
||||||
if (GraphKernelFlags::GetInstance().online_tuning == 0) {
|
if (GraphKernelFlags::GetInstance().online_tuning == 0) {
|
||||||
tuning_flag_ = false;
|
tuning_flag_ = false;
|
||||||
|
@ -70,15 +108,20 @@ bool GraphKernelSplitterWithTuning::Run(const FuncGraphPtr &func_graph) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::map<AnfNodePtr, std::string> node_name;
|
std::map<AnfNodePtr, std::string> node_name;
|
||||||
tuning_path_ = SaveNodesInfo(gknodes, "./split_tuning", &node_name, nullptr);
|
tuning_path_ = SaveNodesInfo(gknodes, "./split_tuning", AkgKernelBuilder::json_option(), &node_name, nullptr);
|
||||||
if (tuning_path_.empty()) {
|
if (tuning_path_.empty()) {
|
||||||
tuning_flag_ = false;
|
tuning_flag_ = false;
|
||||||
} else {
|
} else {
|
||||||
tuning_flag_ = StartTuning(tuning_path_);
|
tuning_flag_ = StartTuning(tuning_path_);
|
||||||
}
|
}
|
||||||
for (const auto &iter : node_name) {
|
for (const auto &iter : node_name) {
|
||||||
AnfUtils::SetNodeAttr("kernel_name", MakeValue(iter.second), iter.first);
|
AnfUtils::SetNodeAttr(kAttrNodeName, MakeValue(iter.second), iter.first);
|
||||||
}
|
}
|
||||||
return GraphKernelSplitter::Run(func_graph);
|
auto changed = GraphKernelSplitter::Run(func_graph);
|
||||||
|
if (!changed) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
RenameKernelFiles(func_graph);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
} // namespace mindspore::graphkernel
|
} // namespace mindspore::graphkernel
|
||||||
|
|
Loading…
Reference in New Issue