Support compile cache in ps mode
commit dcd6f9e491 (parent ce72110c7c)
@@ -25,6 +25,7 @@
 "mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/cast_gpu_kernel.cc" "unknownMacro"
 "mindspore/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc" "nullPointerArithmeticRedundantCheck"
 "mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc" "containerOutOfBounds"
+"mindspore/mindspore/core/load_mindir/anf_model_parser.cc" "stlIfStrFind"
 
 # MindData
 "mindspore/mindspore/ccsrc/minddata/dataset/engine/dataset_iterator.cc" "useStlAlgorithm"
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -32,7 +32,8 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
 
 std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
 
-ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &param_layout_fg = nullptr);
+bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
+                     const FuncGraphPtr &param_layout_fg = nullptr);
 
 void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
 
@@ -1,7 +1,7 @@
 /**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -199,35 +199,10 @@ FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) {
   return bprop_fg;
 }
 
-void ExportBpropToMindIR(const std::string &prim_name, const FuncGraphPtr &func_graph) {
-  std::string bprop_dir = GetBpropDir();
-  auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
-  std::optional<std::string> bprop_mindir_realpath =
-    Common::CreatePrefixPath(bprop_mindir_path + prim_name + kBpropMindIRSuffix, true);
-  if (!bprop_mindir_realpath.has_value()) {
-    MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim_name
-                  << kBpropMindIRSuffix;
-    return;
-  }
-  std::ofstream fout(bprop_mindir_realpath.value());
-  if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open cache file '" << bprop_mindir_realpath.value() << "' failed!" << ErrnoToString(errno);
-    return;
-  }
-  ModelProtoPtr fg_model = GetBinaryProto(func_graph);
-  if (fg_model == nullptr) {
-    MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
-    fout.close();
-    return;
-  }
-  if (!fg_model->SerializeToOstream(&fout)) {
-    MS_LOG(ERROR) << "Failed to cache the bprop of op \"" << prim_name << "\" to file \""
-                  << bprop_mindir_realpath.value() << "\".";
-    fout.close();
-    return;
-  }
-  fout.close();
-  ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
+bool ExportBpropToMindIR(const std::string &prim_name, const FuncGraphPtr &func_graph) {
+  static auto bprop_mindir_dir = GetBpropDir() + kBpropMindIRDir;
+  std::string bprop_mindir_path = bprop_mindir_dir + prim_name + kBpropMindIRSuffix;
+  return DumpBinaryProto(func_graph, bprop_mindir_path);
 }
 
 AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, const PrimitivePtr &prim) {
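The removed serialize-and-chmod boilerplate moves behind the new DumpBinaryProto declared in dump_proto.h earlier in this diff. Its definition is not part of this excerpt, so the following is only a sketch of what such a helper plausibly consolidates, reconstructed from the inline code deleted above; treat every detail as an assumption.

  // Sketch only -- not the actual definition from the commit.
  bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
                       const FuncGraphPtr &param_layout_fg) {
    // Resolve the target path, creating missing directories.
    auto realpath = Common::CreatePrefixPath(file_path, true);
    if (!realpath.has_value()) {
      MS_LOG(ERROR) << "Get real path of file " << file_path << " failed.";
      return false;
    }
    // Build the mindir model proto (optionally with the parallel layout graph).
    ModelProtoPtr fg_model = GetBinaryProto(func_graph, param_layout_fg);
    if (fg_model == nullptr) {
      MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
      return false;
    }
    // Serialize, then make the cache file read-only for the owner.
    std::ofstream fout(realpath.value());
    if (!fout.is_open() || !fg_model->SerializeToOstream(&fout)) {
      MS_LOG(ERROR) << "Failed to write the mindir file " << realpath.value();
      return false;
    }
    fout.close();
    ChangeFileMode(realpath.value(), S_IRUSR);
    return true;
  }

Callers now only need the boolean result, as the next two hunks show.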
@@ -319,6 +294,7 @@ bool NeedExportBpropMindIR(const std::string &prim_name, const std::string &curr
 } // namespace
 
 #ifndef _WIN32
+// For the bprop mindir generator.
 // Given a python primitive or string, export a mindir file from the bprop defined in python.
 void KPrim::ExportBpropMindir(const py::object &obj) {
   std::string prim_name;
@@ -353,7 +329,9 @@ void KPrim::ExportBpropMindir(const py::object &obj) {
   (void)parse::ResolveFuncGraph(func_graph, res);
 
   func_graph->set_bprop_hash(bprop_hash);
-  ExportBpropToMindIR(prim_name, func_graph);
+  if (!ExportBpropToMindIR(prim_name, func_graph)) {
+    MS_LOG(EXCEPTION) << "Failed to export the bprop mindir for " << prim_name;
+  }
 }
 #endif
 
@@ -412,7 +390,9 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceB
     std::string bprop_hash = GetBpropFileHash(fn);
     if (!bprop_hash.empty()) {
       func_graph->set_bprop_hash(bprop_hash);
-      ExportBpropToMindIR(prim->name(), func_graph);
+      if (!ExportBpropToMindIR(prim->name(), func_graph)) {
+        MS_LOG(WARNING) << "Failed to export the bprop mindir for " << prim->name();
+      }
     }
   }
 #endif
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -56,7 +56,7 @@ py::dict GetParameterLayoutFromGraph(const FuncGraphPtr &graph) {
 
 py::dict GetParameterLayoutFromResource(const pipeline::ResourcePtr &resource) {
   py::dict dict;
-  const auto &layout_map = resource->get_layout_map();
+  const auto &layout_map = resource->layout_map();
   for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
     auto name = iter->first;
     auto layout = iter->second;
@@ -116,7 +116,7 @@ py::list GetParallelParameterNameListFromGraph(const FuncGraphPtr &graph) {
 }
 
 py::list GetParallelParameterNameListFromResource(const pipeline::ResourcePtr &resource) {
-  auto &layout_map = resource->get_layout_map();
+  auto &layout_map = resource->layout_map();
   py::list parallel_parameter_name_list;
   for (auto iter = layout_map.begin(); iter != layout_map.end(); ++iter) {
     auto name = iter->first;
@@ -6,6 +6,7 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     "validator.cc"
     "remove_value_node_dup.cc"
     "pipeline_split.cc"
+    "compile_cache_manager.cc"
     "parse/*.cc"
     "static_analysis/*.cc"
     )
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -1351,21 +1351,25 @@ std::vector<ActionItem> GePipeline() {
   return actions;
 }
 
-std::vector<ActionItem> VmPipeline() {
-  auto actions = CommonPipeline();
+std::vector<ActionItem> VmPipeline(const ResourcePtr &resource) {
+  std::vector<ActionItem> actions;
+  // If enable compilation cache and the cache is read successfully, only do the backend actions.
+  if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) {
+    actions = CommonPipeline();
 
-  // optimize
-  (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
+    // optimize
+    (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
 
-  // Add opt-stage python pass stub
-  (void)actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
+    // Add opt-stage python pass stub
+    (void)actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
 
-  (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
+    (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
 
-  // eliminate forward cnode for grad graph
-  (void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));
+    // eliminate forward cnode for grad graph
+    (void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));
 
-  (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
+    (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
+  }
 
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
   if (ps::PSContext::instance()->is_worker()) {
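Condensed, the gate above yields two possible action lists. This is a sketch: it assumes the tail of VmPipeline, outside this hunk, still unconditionally appends the backend actions that the deleted BackendPipeline used to provide.

  // Cache hit: EnableCompileCache() && func_graph() != nullptr
  //   actions = { "task_emit", "execute" }
  // Cache miss, or compile cache disabled:
  //   actions = CommonPipeline() + { "optimize", "py_opt", "auto_monad_reorder",
  //             "eliminate_forward_cnode", "validate", ..., "task_emit", "execute" }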
@@ -1390,14 +1394,6 @@ std::vector<ActionItem> VmPipeline() {
   return actions;
 }
 
-std::vector<ActionItem> BackendPipeline() {
-  std::vector<ActionItem> actions;
-  // compile the ANF graph
-  (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
-  // to execute the graph
-  (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
-  return actions;
-}
 std::vector<ActionItem> MindIRPipeline() {
   auto context_ptr = MsContext::GetInstance();
   if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
@@ -1415,8 +1411,12 @@ std::vector<ActionItem> MindIRPipeline() {
   (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
   return actions;
 }
+
 #if ((defined ENABLE_CPU) && (!defined _WIN32))
-std::vector<ActionItem> ServerPipeline() {
+std::vector<ActionItem> ServerPipeline(const ResourcePtr &resource) {
+  if (resource->EnableCompileCache() && resource->func_graph() != nullptr) {
+    return {std::make_pair("server", StartServerAction)};
+  }
   auto actions = CommonPipeline();
   (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
@@ -1424,7 +1424,10 @@ std::vector<ActionItem> ServerPipeline() {
   return actions;
 }
 
-std::vector<ActionItem> PServerPipeline() {
+std::vector<ActionItem> PServerPipeline(const ResourcePtr &resource) {
+  if (resource->EnableCompileCache() && resource->func_graph() != nullptr) {
+    return {std::make_pair("pserver", StartPSServerAction)};
+  }
   auto actions = CommonPipeline();
   (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
   (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
@@ -1433,7 +1436,10 @@ std::vector<ActionItem> PServerPipeline() {
   return actions;
 }
 
-std::vector<ActionItem> PSchedulerPipeline() {
+std::vector<ActionItem> PSchedulerPipeline(const ResourcePtr &resource) {
+  if (resource->EnableCompileCache() && resource->func_graph() != nullptr) {
+    return {std::make_pair("scheduler", StartPSSchedulerAction)};
+  }
   auto actions = CommonPipeline();
   (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
   (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -49,12 +49,11 @@ bool StartServerAction(const ResourcePtr &res);
 bool DistributedSplitAction(const ResourcePtr &res);
 
 std::vector<ActionItem> GePipeline();
-std::vector<ActionItem> VmPipeline();
+std::vector<ActionItem> VmPipeline(const ResourcePtr &resource);
 std::vector<ActionItem> MindIRPipeline();
-std::vector<ActionItem> BackendPipeline();
-std::vector<ActionItem> PServerPipeline();
-std::vector<ActionItem> ServerPipeline();
-std::vector<ActionItem> PSchedulerPipeline();
+std::vector<ActionItem> PServerPipeline(const ResourcePtr &resource);
+std::vector<ActionItem> ServerPipeline(const ResourcePtr &resource);
+std::vector<ActionItem> PSchedulerPipeline(const ResourcePtr &resource);
 abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                          const abstract::AbstractBasePtrList &args_spec, bool clear = false);
 FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
@@ -0,0 +1,277 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pipeline/jit/compile_cache_manager.h"
+#include <vector>
+#include <algorithm>
+#include <map>
+#include <utility>
+#include "pipeline/jit/parse/data_converter.h"
+#include "frontend/parallel/context.h"
+#include "debug/common.h"
+#include "debug/anf_ir_dump.h"
+#include "debug/dump_proto.h"
+#include "utils/system/sha256.h"
+
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
+#include "ps/ps_context.h"
+#include "ps/core/node.h"
+#include "distributed/cluster/cluster_context.h"
+#endif
+
+namespace mindspore {
+namespace pipeline {
+namespace {
+constexpr char kCompileCacheSubDir[] = "graph_cache";
+constexpr char kCompileCacheFileName[] = "compile_cache";
+constexpr char kCompileCacheFileSuffix[] = ".mindir";
+constexpr char kDepFilesHashPath[] = "compile_dependency.hash";
+constexpr char kRoleServer[] = "server_";
+constexpr char kRolePServer[] = "pserver_";
+constexpr char kRolePScheduler[] = "pscheduler_";
+
+std::string GetUserDefinedCachePath() {
+  auto user_defined_path = MsContext::GetInstance()->get_param<std::string>(MS_CTX_COMPILE_CACHE_PATH);
+  if (!user_defined_path.empty()) {
+    user_defined_path += "/";
+    return user_defined_path;
+  }
+  user_defined_path = common::GetEnv("MS_COMPILER_CACHE_PATH");
+  if (!user_defined_path.empty()) {
+    user_defined_path += "/";
+  }
+  return user_defined_path;
+}
+
+std::string GetCompileCacheDir() {
+  static const std::string user_defined_path = GetUserDefinedCachePath();
+  static uint32_t rank_id = IsStandAlone() ? 0 : GetRank();
+  static const std::string compile_cache_dir =
+    user_defined_path + "rank_" + std::to_string(rank_id) + "/" + kCompileCacheSubDir;
+  return compile_cache_dir;
+}
+
+std::string GetRole() {
+#if ((defined ENABLE_CPU) && (!defined _WIN32))
+  const std::string &server_mode = ps::PSContext::instance()->server_mode();
+  if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
+      ps::PSContext::instance()->is_server()) {
+    return kRoleServer;
+  }
+  if (ps::PSContext::instance()->is_server()) {
+    return kRolePServer;
+  }
+  if (ps::PSContext::instance()->is_scheduler()) {
+    return kRolePScheduler;
+  }
+  if (distributed::cluster::ClusterContext::instance()->initialized()) {
+    auto node = distributed::cluster::ClusterContext::instance()->node();
+    MS_EXCEPTION_IF_NULL(node);
+    MS_LOG(INFO) << "Cluster is initialized. This node role is " << node->role();
+    switch (node->role()) {
+      case ps::core::NodeRole::SERVER:
+        return kRolePServer;
+      case ps::core::NodeRole::SCHEDULER:
+        return kRolePScheduler;
+      default:
+        break;
+    }
+  }
+#endif
+  return "";
+}
+
+std::string GetCompileCachePath(size_t idx) {
+  return GetCompileCacheDir() + "/" + GetRole() + kCompileCacheFileName + "_" + std::to_string(idx) +
+         kCompileCacheFileSuffix;
+}
+
+std::string GetDepFilesHashPath() {
+  static const std::string dep_files_hash_path = GetCompileCacheDir() + "/" + GetRole() + kDepFilesHashPath;
+  return dep_files_hash_path;
+}
+
+std::string GetCompileDepFilesHash(const py::list &dep_files) {
+  MS_LOG(DEBUG) << "Dependency files size: " << dep_files.size();
+  std::vector<std::string> dep_files_path;
+  for (auto dep_file : dep_files) {
+    auto file_path = py::cast<std::string>(dep_file);
+    MS_LOG(DEBUG) << "Dependency file path: " << file_path;
+    (void)dep_files_path.emplace_back(file_path);
+  }
+  std::sort(dep_files_path.begin(), dep_files_path.end());
+  std::string files_hash;
+  for (const auto &path : dep_files_path) {
+    std::string file_hash = system::sha256::GetHashFromFile(path);
+    files_hash += file_hash;
+  }
+  return files_hash;
+}
+
+bool CheckDepFilesHashConsistency(const std::string &current_dep_files_hash) {
+  if (current_dep_files_hash.empty()) {
+    MS_LOG(ERROR) << "Get current dependency files hash failed.";
+    return false;
+  }
+  std::string dep_files_hash_path = GetDepFilesHashPath();
+  auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
+  if (!realpath.has_value()) {
+    MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
+    return false;
+  }
+  std::fstream input(realpath.value(), std::ios::in | std::ios::binary);
+  if (!input) {
+    MS_LOG(WARNING) << "Open the hash file " << realpath.value() << " failed. The file may not exist."
+                    << ErrnoToString(errno);
+    return false;
+  }
+  std::string checkpoint_hash;
+  input >> checkpoint_hash;
+  if (checkpoint_hash.empty()) {
+    MS_LOG(ERROR) << "Get the compilation dependency files hash from " << realpath.value() << " failed.";
+    return false;
+  }
+  if (checkpoint_hash != current_dep_files_hash) {
+    MS_LOG(WARNING) << "The compilation dependency files are changed.";
+    return false;
+  }
+  return true;
+}
+
+std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
+  std::map<string, ValuePtr> ret{};
+  for (auto weight = weights.begin(); weight != weights.end(); ++weight) {
+    auto weight_name = py::cast<std::string>(weight->first);
+    auto weight_value = parse::data_converter::PyDataToValue(py::cast<py::object>(weight->second));
+    ret[weight_name] = weight_value;
+  }
+  return ret;
+}
+
+std::pair<FuncGraphPtr, LayoutMap> LoadFuncGraphFromMindIR(const py::dict &weights, bool has_parallel_info,
+                                                           size_t idx) {
+  LayoutMap layout_map;
+  std::string compile_cache_path = GetCompileCachePath(idx);
+  auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
+  if (!realpath.has_value()) {
+    MS_LOG(ERROR) << "Get real path of file " << compile_cache_path << " failed.";
+    return std::make_pair(nullptr, layout_map);
+  }
+  std::ifstream f(realpath.value());
+  bool file_is_good = f.good();
+  f.close();
+  if (!file_is_good) {
+    MS_LOG(WARNING) << "Open the compilation cache file " << realpath.value() << " failed.";
+    return std::make_pair(nullptr, layout_map);
+  }
+  MindIRLoader mindir_loader;
+  mindir_loader.set_need_renormalize(false);
+  mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
+  mindir_loader.set_has_parallel_info(has_parallel_info);
+  auto fg = mindir_loader.LoadMindIR(realpath.value());
+  return std::make_pair(fg, mindir_loader.layout_map());
+}
+
+bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, size_t idx) {
+  std::string compile_cache_path = GetCompileCachePath(idx);
+  return DumpBinaryProto(fg, compile_cache_path, layout_fg);
+}
+
+bool ExportDepFilesHash(const std::string &compile_cache_dep_files_hash) {
+  std::string dep_files_hash_path = GetDepFilesHashPath();
+  auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
+  if (!realpath.has_value()) {
+    MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
+    return false;
+  }
+
+  ChangeFileMode(realpath.value(), S_IWUSR);
+  std::ofstream fout(realpath.value());
+  if (!fout.is_open()) {
+    MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
+    return false;
+  }
+  fout << compile_cache_dep_files_hash;
+  fout.close();
+  ChangeFileMode(realpath.value(), S_IRUSR);
+  return true;
+}
+}  // namespace
+
+void CompileCacheManager::CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg) const {
+  if (fg == nullptr) {
+    MS_LOG(ERROR) << "The func_graph to be cached is null.";
+    return;
+  }
+  if (!ExportFuncGraphToMindIR(fg, layout_fg, compile_cache_id_)) {
+    MS_LOG(ERROR) << "Failed to cache graph: " << fg->ToString();
+    return;
+  }
+  if (compile_cache_id_ == 0 && !ExportDepFilesHash(compile_cache_dep_files_hash_)) {
+    MS_LOG(ERROR) << "Failed to cache the dependency files hash";
+  }
+}
+
+void CompileCacheManager::InitCompileCacheHash(const py::list &compile_cache_dep_files) {
+  compile_cache_dep_files_hash_ = GetCompileDepFilesHash(compile_cache_dep_files);
+}
+
+FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr &manager, const py::dict &weights,
+                                                     const std::string &queue_name) {
+  // Compare the dependency files hash.
+  if (!CheckDepFilesHashConsistency(compile_cache_dep_files_hash_)) {
+    MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions.";
+    return nullptr;
+  }
+
+  // Determine whether to load parallel information.
+  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
+  bool has_parallel_info = false;
+  if ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL)) {
+    has_parallel_info = true;
+  }
+  // Load the compilation cache file.
+  auto pair = LoadFuncGraphFromMindIR(weights, has_parallel_info, compile_cache_id_);
+  if (pair.first == nullptr) {
+    MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
+    return nullptr;
+  }
+  auto fg = pair.first;
+  layout_map_ = pair.second;
+
+  MS_LOG(WARNING) << "Use the compilation cache and execute the backend actions only. Be aware of correctness risks.";
+  FuncGraphManagerPtr mng = fg->manager();
+  if (mng == nullptr) {
+    MS_EXCEPTION_IF_NULL(manager);
+    manager->AddFuncGraph(fg);
+    fg->set_manager(manager);
+  }
+  // The value of attr "shared_name" will changed every time.
+  auto cnodes = fg->GetOrderedCnodes();
+  for (auto cnode : cnodes) {
+    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (prim != nullptr && prim->HasAttr("shared_name")) {
+      prim->set_attr("shared_name", MakeValue(queue_name));
+      break;
+    }
+  }
+  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
+    DumpIR("cache_loaded_graph_" + std::to_string(compile_cache_id_) + ".ir", fg);
+  }
+  return fg;
+}
+}  // namespace pipeline
+}  // namespace mindspore
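The new manager packs the whole cache life cycle into three calls. A usage sketch, condensed from Resource::GetCompileCacheResource and Resource::CacheFuncGraph shown later in this diff (variable names here are illustrative):

  auto cache_mgr = std::make_shared<CompileCacheManager>(/*compile_cache_id=*/0);
  cache_mgr->InitCompileCacheHash(compile_cache_dep_files);  // sha256 over the python dep files
  FuncGraphPtr fg = cache_mgr->GetCachedFuncGraph(func_graph_manager, weights, queue_name);
  if (fg == nullptr) {
    // Hash mismatch or unreadable cache file: compile from scratch,
    // then persist the result for the next run:
    // cache_mgr->CacheFuncGraph(compiled_fg, layout_fg);
  }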
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_COMPILE_CACHE_MANAGER_H_
+#define MINDSPORE_CCSRC_PIPELINE_JIT_COMPILE_CACHE_MANAGER_H_
+
+#include <string>
+#include <memory>
+#include "pybind11/pybind11.h"
+#include "ir/func_graph.h"
+#include "load_mindir/load_model.h"
+
+namespace mindspore {
+namespace pipeline {
+namespace py = pybind11;
+// A class for loading and caching the func_graph.
+class CompileCacheManager {
+ public:
+  explicit CompileCacheManager(size_t compile_cache_id) : compile_cache_id_(compile_cache_id) {}
+
+  ~CompileCacheManager() = default;
+
+  // Get the hash of dependent files when compiling graph.
+  void InitCompileCacheHash(const py::list &compile_cache_dep_files);
+  // Load the cached func_graph from mindir file.
+  FuncGraphPtr GetCachedFuncGraph(const FuncGraphManagerPtr &manager, const py::dict &weights,
+                                  const std::string &queue_name);
+  // Export the func_graph to mindir file.
+  void CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg) const;
+
+  const LayoutMap &layout_map() const { return layout_map_; }
+
+ private:
+  size_t compile_cache_id_;
+  std::string compile_cache_dep_files_hash_;
+  LayoutMap layout_map_;
+};
+using CompileCacheManagerPtr = std::shared_ptr<CompileCacheManager>;
+}  // namespace pipeline
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_COMPILE_CACHE_MANAGER_H_
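For orientation, the functions in the .cc file above place the cache artifacts in a per-rank, per-role directory; the root comes from the MS_CTX_COMPILE_CACHE_PATH context parameter or the MS_COMPILER_CACHE_PATH environment variable (layout sketch derived from GetCompileCacheDir, GetCompileCachePath and GetDepFilesHashPath):

  <cache_root>/rank_<rank_id>/graph_cache/<role>compile_cache_<idx>.mindir
  <cache_root>/rank_<rank_id>/graph_cache/<role>compile_dependency.hash
  // <role> is "", "server_", "pserver_" or "pscheduler_", per GetRole()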
@@ -1,7 +1,7 @@
 /**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -59,7 +59,6 @@
 #include "backend/session/executor_manager.h"
 #include "runtime/hardware/device_context_manager.h"
 #include "runtime/device/kernel_runtime_manager.h"
-#include "utils/system/sha256.h"
 
 #ifndef ENABLE_SECURITY
 #ifdef ENABLE_D
@@ -127,11 +126,6 @@ std::unordered_map<abstract::AbstractBasePtrList, uint64_t, abstract::AbstractBa
 g_args_cache;
 
 namespace {
-constexpr char kCompileCacheSubDir[] = "graph_cache";
-constexpr char kCompileCacheFileName[] = "compile_cache";
-constexpr char kCompileCacheFileSuffix[] = ".mindir";
-constexpr char kDepFilesHashPath[] = "compile_dependency.hash";
-
 #ifdef ENABLE_DUMP_IR
 std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
   std::ostringstream oss;
@@ -211,98 +205,6 @@ void SetLoopCount(const ResourcePtr &resource) {
   }
 }
 
-std::string GetUserDefinedCachePath() {
-  auto user_defined_path = MsContext::GetInstance()->get_param<std::string>(MS_CTX_COMPILE_CACHE_PATH);
-  if (!user_defined_path.empty()) {
-    user_defined_path += "/";
-    return user_defined_path;
-  }
-  user_defined_path = common::GetEnv("MS_COMPILER_CACHE_PATH");
-  if (!user_defined_path.empty()) {
-    user_defined_path += "/";
-  }
-  return user_defined_path;
-}
-
-std::string GetCompileCacheDir() {
-  static const std::string user_defined_path = GetUserDefinedCachePath();
-  static uint32_t rank_id = IsStandAlone() ? 0 : GetRank();
-  static const std::string compile_cache_dir =
-    user_defined_path + "rank_" + std::to_string(rank_id) + "/" + kCompileCacheSubDir;
-  return compile_cache_dir;
-}
-
-std::string GetCompileCachePath(size_t idx) {
-  return GetCompileCacheDir() + "/" + kCompileCacheFileName + "_" + std::to_string(idx) + kCompileCacheFileSuffix;
-}
-
-std::string GetDepFilesHashPath() {
-  static const std::string dep_files_hash_path = GetCompileCacheDir() + "/" + kDepFilesHashPath;
-  return dep_files_hash_path;
-}
-
-size_t GetCompileCacheGraphId() {
-  static size_t idx = 0;
-  return idx++;
-}
-
-std::string GetCompileDepFilesHash(const py::list &dep_files) {
-  MS_LOG(DEBUG) << "Dependency files size: " << dep_files.size();
-  std::vector<std::string> dep_files_path;
-  for (auto dep_file : dep_files) {
-    auto file_path = py::cast<std::string>(dep_file);
-    MS_LOG(DEBUG) << "Dependency file path: " << file_path;
-    (void)dep_files_path.emplace_back(file_path);
-  }
-  std::sort(dep_files_path.begin(), dep_files_path.end());
-  std::string files_hash;
-  for (const auto &path : dep_files_path) {
-    std::string file_hash = system::sha256::GetHashFromFile(path);
-    files_hash += file_hash;
-  }
-  return files_hash;
-}
-
-bool CheckDepFilesHashConsistency(const std::string &current_dep_files_hash) {
-  if (current_dep_files_hash.empty()) {
-    MS_LOG(ERROR) << "Get current dependency files hash failed.";
-    return false;
-  }
-  std::string dep_files_hash_path = GetDepFilesHashPath();
-  auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
-  if (!realpath.has_value()) {
-    MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
-    return false;
-  }
-  std::fstream input(realpath.value(), std::ios::in | std::ios::binary);
-  if (!input) {
-    MS_LOG(WARNING) << "Open the hash file " << realpath.value() << " failed. The file may not exist."
-                    << ErrnoToString(errno);
-    return false;
-  }
-  std::string checkpoint_hash;
-  input >> checkpoint_hash;
-  if (checkpoint_hash.empty()) {
-    MS_LOG(ERROR) << "Get the compilation dependency files hash from " << realpath.value() << " failed.";
-    return false;
-  }
-  if (checkpoint_hash != current_dep_files_hash) {
-    MS_LOG(WARNING) << "The compilation dependency files are changed.";
-    return false;
-  }
-  return true;
-}
-
-std::map<string, ValuePtr> GenerateWeightsValueMap(const py::dict &weights) {
-  std::map<string, ValuePtr> ret{};
-  for (auto weight = weights.begin(); weight != weights.end(); ++weight) {
-    auto weight_name = py::cast<std::string>(weight->first);
-    auto weight_value = parse::data_converter::PyDataToValue(py::cast<py::object>(weight->second));
-    ret[weight_name] = weight_value;
-  }
-  return ret;
-}
-
 std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
   std::map<string, string> ret{};
   for (auto jit_param = jit_config.begin(); jit_param != jit_config.end(); ++jit_param) {
@@ -313,148 +215,6 @@ std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
   return ret;
 }
 
-FuncGraphPtr LoadFuncGraphFromMindIR(const ResourcePtr &resource, const py::dict &weights, bool has_parallel_info) {
-  const size_t idx = resource->compile_cache_id();
-  std::string compile_cache_path = GetCompileCachePath(idx);
-  auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
-  if (!realpath.has_value()) {
-    MS_LOG(ERROR) << "Get real path of file " << compile_cache_path << " failed.";
-    return nullptr;
-  }
-  std::ifstream f(realpath.value());
-  bool file_is_good = f.good();
-  f.close();
-  if (!file_is_good) {
-    MS_LOG(WARNING) << "Open the compilation cache file " << realpath.value() << " failed.";
-    return nullptr;
-  }
-  MindIRLoader mindir_loader;
-  mindir_loader.set_need_renormalize(false);
-  mindir_loader.set_weights_value_map(GenerateWeightsValueMap(weights));
-  mindir_loader.set_has_parallel_info(has_parallel_info);
-  auto fg = mindir_loader.LoadMindIR(realpath.value());
-  if (has_parallel_info) {
-    resource->set_layout_map(mindir_loader.get_layout_map());
-  }
-  return fg;
-}
-
-FuncGraphPtr GetCachedFuncGraph(const ResourcePtr &resource, const py::dict &weights, const std::string &queue_name) {
-  MS_EXCEPTION_IF_NULL(resource);
-  // Compare the dependency files hash.
-  if (!CheckDepFilesHashConsistency(resource->compile_cache_dep_files_hash())) {
-    MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions.";
-    return nullptr;
-  }
-
-  // Determine whether to load parallel information.
-  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
-  bool has_parallel_info = false;
-  if ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL)) {
-    has_parallel_info = true;
-  }
-  // Load the compilation cache file.
-  FuncGraphPtr fg = LoadFuncGraphFromMindIR(resource, weights, has_parallel_info);
-  if (fg == nullptr) {
-    MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions.";
-    return nullptr;
-  }
-
-  MS_LOG(WARNING) << "Use the compilation cache and execute the backend actions only. Be aware of correctness risks.";
-  FuncGraphManagerPtr mng = fg->manager();
-  if (mng == nullptr) {
-    auto res_mng = resource->manager();
-    MS_EXCEPTION_IF_NULL(res_mng);
-    res_mng->AddFuncGraph(fg);
-    fg->set_manager(res_mng);
-  }
-  // The value of attr "shared_name" will changed every time.
-  auto cnodes = fg->GetOrderedCnodes();
-  for (auto cnode : cnodes) {
-    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
-    if (prim != nullptr && prim->HasAttr("shared_name")) {
-      prim->set_attr("shared_name", MakeValue(queue_name));
-      break;
-    }
-  }
-  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
-    DumpIR("cache_loaded_graph_" + std::to_string(resource->compile_cache_id()) + ".ir", fg);
-  }
-  return fg;
-}
-
-bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, size_t idx) {
-  std::string compile_cache_path = GetCompileCachePath(idx);
-  auto realpath = Common::CreatePrefixPath(compile_cache_path, true);
-  if (!realpath.has_value()) {
-    MS_LOG(ERROR) << "Get real path of file " << compile_cache_path << " failed.";
-    return false;
-  }
-
-  ChangeFileMode(realpath.value(), S_IRWXU);
-  std::ofstream fout(realpath.value());
-  if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
-    return false;
-  }
-  ModelProtoPtr fg_model = GetBinaryProto(fg, layout_fg);
-  if (fg_model == nullptr) {
-    MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed.";
-    fout.close();
-    return false;
-  }
-  if (!fg_model->SerializeToOstream(&fout)) {
-    MS_LOG(ERROR) << "Failed to write the graph compilation cache to file " << realpath.value();
-    fout.close();
-    return false;
-  }
-  fout.close();
-  ChangeFileMode(realpath.value(), S_IRUSR);
-  return true;
-}
-
-bool ExportDepFilesHash(const ResourcePtr &resource) {
-  std::string dep_files_hash_path = GetDepFilesHashPath();
-  auto realpath = Common::CreatePrefixPath(dep_files_hash_path, true);
-  if (!realpath.has_value()) {
-    MS_LOG(ERROR) << "Get real path of file " << dep_files_hash_path << " failed.";
-    return false;
-  }
-
-  ChangeFileMode(realpath.value(), S_IRWXU);
-  std::ofstream fout(realpath.value());
-  if (!fout.is_open()) {
-    MS_LOG(ERROR) << "Open cache file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
-    return false;
-  }
-  fout << resource->compile_cache_dep_files_hash();
-  fout.close();
-  ChangeFileMode(realpath.value(), S_IRUSR);
-  return true;
-}
-
-void CacheFuncGraph(const ResourcePtr &resource) {
-  MS_EXCEPTION_IF_NULL(resource);
-  auto fg = resource->func_graph();
-  if (fg == nullptr) {
-    MS_LOG(ERROR) << "The func_graph to be cached is null.";
-    return;
-  }
-  FuncGraphPtr layout_fg = nullptr;
-  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
-  if (fg->has_flag(parallel::AUTO_PARALLEL) &&
-      ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
-    layout_fg = resource->GetResult(kStepParallelGraph).cast<FuncGraphPtr>();
-  }
-  if (!ExportFuncGraphToMindIR(fg, layout_fg, resource->compile_cache_id())) {
-    MS_LOG(ERROR) << "Failed to cache graph: " << fg->ToString();
-    return;
-  }
-  if (resource->compile_cache_id() == 0 && !ExportDepFilesHash(resource)) {
-    MS_LOG(ERROR) << "Failed to cache the dependency files hash";
-  }
-}
-
 void RecordInitStatus() {
   static bool printed = false;
   if (!printed) {
@@ -951,14 +711,14 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
     const std::string &server_mode = ps::PSContext::instance()->server_mode();
     if ((server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) &&
         ps::PSContext::instance()->is_server()) {
-      return ServerPipeline();
+      return ServerPipeline(resource);
     }
     if (ps::PSContext::instance()->is_server()) {
      resource->SetResult(kBackend, compile::CreateBackend());
-      return PServerPipeline();
+      return PServerPipeline(resource);
    }
    if (ps::PSContext::instance()->is_scheduler()) {
-      return PSchedulerPipeline();
+      return PSchedulerPipeline(resource);
    }
    if (distributed::cluster::ClusterContext::instance()->initialized()) {
      auto node = distributed::cluster::ClusterContext::instance()->node();
@@ -966,9 +726,9 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
       MS_LOG(INFO) << "Cluster is initialized. This node role is " << node->role();
       switch (node->role()) {
         case ps::core::NodeRole::SERVER:
-          return PServerPipeline();
+          return PServerPipeline(resource);
         case ps::core::NodeRole::SCHEDULER:
-          return PSchedulerPipeline();
+          return PSchedulerPipeline(resource);
         default:
           break;
       }
@@ -977,18 +737,31 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
 
   if (use_vm && backend != "ge" && !is_air) {
     compile::SetMindRTEnable();
-    // If enable compilation cache and the cache is read successfully, do the backend actions only.
-    if (resource->enable_compile_cache() && resource->func_graph() != nullptr) {
-      return BackendPipeline();
-    }
     if (IsPhaseLoadFromMindIR(phase)) {
       return MindIRPipeline();
     }
-    return VmPipeline();
+    return VmPipeline(resource);
   }
   return GePipeline();
 }
 
+void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase) {
+  // The compilation cache only support for training currently.
+  // If enable compilation cache, it will get a non-empty dependent files list from python.
+  if (!IsPhaseTrain(phase) || compile_cache_dep_files_.empty()) {
+    return;
+  }
+#ifdef ENABLE_PROFILE
+  double t1 = GetTime();
+#endif
+  static size_t idx = 0;
+  resource->GetCompileCacheResource(compile_cache_dep_files_, weights_, queue_name_, idx++);
+#ifdef ENABLE_PROFILE
+  double t2 = GetTime();
+  MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1);
+#endif
+}
+
 bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple &args, const py::object &phase_obj,
                                    bool use_vm) {
   // Check if the phase is valid.
@@ -1005,29 +778,15 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
   CheckArgsValid(args);
 
   auto phase = py::cast<std::string>(phase_obj);
+  phase_ = phase;
   MS_LOG(INFO) << "Start compiling, phase: " << phase;
   MS_LOG(DEBUG) << "source: {" << py::str(source_obj) << "}\nargs: " << py::str(const_cast<py::tuple &>(args));
 
   ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
   ResourcePtr resource = std::make_shared<Resource>(source_obj);
 
-  // If enable compilation cache, it will get a non-empty dependent files list from python.
-  if (!compile_cache_dep_files_.empty()) {
-#ifdef ENABLE_PROFILE
-    double t1 = GetTime();
-#endif
-    resource->set_enable_compile_cache(true);
-    resource->set_compile_cache_id(GetCompileCacheGraphId());
-    resource->set_compile_cache_dep_files_hash(GetCompileDepFilesHash(compile_cache_dep_files_));
-    resource->set_func_graph(GetCachedFuncGraph(resource, weights_, queue_name_));
-#ifdef ENABLE_PROFILE
-    double t2 = GetTime();
-    MsProfile::StatTime("LoadCachedFuncGraph", t2 - t1);
-#endif
-  }
-
-  phase_ = phase;
+  InitCompileCacheInfo(resource, phase);
   ConfigManager::GetInstance().ResetQueue(queue_name_);
 
   auto actions = GetPipeline(resource, phase, use_vm);
   std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(actions, phase));
@@ -1063,7 +822,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
   executor_info->arg_list_size = size;
   executor_info->resource = resource;
   info_[phase] = executor_info;
-  pip->Run(phase);
+  pip->Run();
 
   // Save the compiled graph to MsPipeLine.
   SaveCompiledGraph(phase);
@@ -1160,17 +919,18 @@ bool GraphExecutorPy::Compile(const py::object &source_obj, const py::tuple &arg
   return ret_value;
 }
 
-void CacheValidateFuncGraph(const std::string &phase, const ResourcePtr &resource) {
-  if (resource->enable_compile_cache()) {
-#ifdef ENABLE_PROFILE
-    double t1 = GetTime();
-#endif
-    CacheFuncGraph(resource);
-#ifdef ENABLE_PROFILE
-    double t2 = GetTime();
-    MsProfile::StatTime("SaveCacheFuncGraph", t2 - t1);
-#endif
-  }
+void CacheValidateFuncGraph(const ResourcePtr &resource) {
+  if (!resource->EnableCompileCache()) {
+    return;
+  }
+#ifdef ENABLE_PROFILE
+  double t1 = GetTime();
+#endif
+  resource->CacheFuncGraph();
+#ifdef ENABLE_PROFILE
+  double t2 = GetTime();
+  MsProfile::StatTime("SaveCacheFuncGraph", t2 - t1);
+#endif
 }
 
 void CheckInterpretNodeLineInfos() {
@@ -1249,12 +1009,12 @@ void SaveGraphForReadability(const std::string &action_name, const FuncGraphPtr
 }
 #endif
 
-void Pipeline::Run(const std::string &phase) {
+void Pipeline::Run() {
   MS_LOG(INFO) << "Pipeline run";
   MS_EXCEPTION_IF_NULL(resource_);
   FuncGraphPtr user_graph = nullptr;
 
-  WITH(MsProfile::GetProfile())[&user_graph, &phase, this]() {
+  WITH(MsProfile::GetProfile())[&user_graph, this]() {
     size_t i = 0;
     for (auto &action : actions_) {
 #ifdef ENABLE_TIMELINE
@@ -1271,7 +1031,7 @@ void Pipeline::Run(const std::string &phase) {
       SetLoopCount(resource_);
     } else if (action.first == "validate") {
       CheckInterpretNodeLineInfos();
-      CacheValidateFuncGraph(phase, resource_);
+      CacheValidateFuncGraph(resource_);
 #ifndef ENABLE_SECURITY
 #ifdef ENABLE_D
       FuncGraphPtr graph = resource_->func_graph();
@@ -1906,6 +1666,10 @@ void ClearResAtexit() {
   MS_LOG(INFO) << "Start clear parallel::entire_costgraph...";
   parallel::entire_costgraph.reset();
   MS_LOG(INFO) << "End clear parallel::entire_costgraph...";
+
+  MS_LOG(INFO) << "Start clear ProtobufLibrary...";
+  google::protobuf::ShutdownProtobufLibrary();
+  MS_LOG(INFO) << "End clear ProtobufLibrary...";
 }
 
 py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@ class Pipeline {
 
   ~Pipeline() = default;
 
-  void Run(const std::string &phase);
+  void Run();
 
   ResourcePtr resource() { return resource_; }
 
@@ -143,6 +143,10 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
   // 'validate' stage
   static std::vector<ActionItem> FilterActions(const std::vector<ActionItem> &actions, const std::string &phase);
 
+  void DelOneNetRes(const py::handle &py_phase);
+  // If enable compile cache, get the compile cache resource.
+  void InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase);
+
   std::map<std::string, ExecutorInfoPtr> info_;
   static std::shared_ptr<GraphExecutorPy> executor_;
   static std::mutex instance_lock_;
@@ -159,7 +163,6 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
   py::list compile_cache_dep_files_;
   py::dict weights_;
   std::map<PyObject *, AbstractBasePtr> cur_convert_input_;
-  void DelOneNetRes(const py::handle &py_phase);
 };
 using GraphExecutorPyPtr = std::shared_ptr<GraphExecutorPy>;
 
@@ -1,7 +1,7 @@
 /**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
 #include "pipeline/jit/parse/data_converter.h"
 #include "frontend/operator/ops.h"
 #include "frontend/optimizer/ad/dfunctor.h"
+#include "frontend/parallel/context.h"
 
 namespace mindspore {
 // namespace to support opmap definition
@@ -322,6 +323,24 @@ Any Resource::GetAttrPtr(const TypeId &type, const std::string &name) {
   return GetMethodOrAttr(name, type_id, attr_map);
 }
 
+void Resource::GetCompileCacheResource(const py::list &compile_cache_dep_files, const py::dict &weights,
+                                       const std::string &queue_name, size_t compile_cache_id) {
+  compile_cache_manager_ = std::make_shared<CompileCacheManager>(compile_cache_id);
+  compile_cache_manager_->InitCompileCacheHash(compile_cache_dep_files);
+  func_graph_ = compile_cache_manager_->GetCachedFuncGraph(manager_, weights, queue_name);
+  layout_map_ = compile_cache_manager_->layout_map();
+}
+
+void Resource::CacheFuncGraph() const {
+  FuncGraphPtr layout_fg = nullptr;
+  std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
+  if (func_graph_->has_flag(parallel::AUTO_PARALLEL) &&
+      ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
+    layout_fg = GetResult(kStepParallelGraph).cast<FuncGraphPtr>();
+  }
+  compile_cache_manager_->CacheFuncGraph(func_graph_, layout_fg);
+}
+
 void Resource::Clean() {
   // AbstractTensor->elements() will be saved in AbstractBasePtrList
   args_spec_.clear();
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2021 Huawei Technologies Co., Ltd
+ * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -35,6 +35,7 @@
 #include "pipeline/jit/static_analysis/prim.h"
 #include "pipeline/jit/static_analysis/static_analysis.h"
 #include "load_mindir/load_model.h"
+#include "pipeline/jit/compile_cache_manager.h"
 
 namespace mindspore {
 namespace pipeline {
@@ -88,19 +89,13 @@ class Resource : public ResourceBase {
   bool vm_loop_flag() const { return vm_loop_flag_; }
   int64_t loop_size() const { return loop_size_; }
 
-  void set_layout_map(const LayoutMap &layout_map) { layout_map_ = layout_map; }
-  const LayoutMap &get_layout_map() const { return layout_map_; }
+  const LayoutMap &layout_map() const { return layout_map_; }
 
-  bool enable_compile_cache() { return enable_compile_cache_; }
-  void set_enable_compile_cache(bool enable_compile_cache) { enable_compile_cache_ = enable_compile_cache; }
-
-  size_t compile_cache_id() { return compile_cache_id_; }
-  void set_compile_cache_id(size_t compile_cache_id) { compile_cache_id_ = compile_cache_id; }
-
-  const std::string &compile_cache_dep_files_hash() { return compile_cache_dep_files_hash_; }
-  void set_compile_cache_dep_files_hash(const std::string &compile_cache_dep_files_hash) {
-    compile_cache_dep_files_hash_ = compile_cache_dep_files_hash;
-  }
+  // Get the cached func_graph and parameters layout map.
+  void GetCompileCacheResource(const py::list &compile_cache_dep_files, const py::dict &weights,
+                               const std::string &queue_name, size_t compile_cache_id);
+  void CacheFuncGraph() const;
+  bool EnableCompileCache() const { return compile_cache_manager_ != nullptr; }
 
   // Reclaim resource and clear the cache.
   // GraphExecutorPy::Compile() can be called multiple times, so cache
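Note that the flag, id and hash setters are gone: EnableCompileCache() now simply means "a CompileCacheManager was created", which only happens when dependency files were registered. A condensed caller-side sketch, pieced together from InitCompileCacheInfo and the pipeline functions earlier in this diff (names here are illustrative):

  ResourcePtr resource = std::make_shared<Resource>(source_obj);
  if (!compile_cache_dep_files.empty()) {
    resource->GetCompileCacheResource(compile_cache_dep_files, weights, queue_name, id);
  }
  if (resource->EnableCompileCache() && resource->func_graph() != nullptr) {
    // Cache hit: the pipelines return only their backend/server actions.
  } else {
    // Full compilation; the "validate" action later calls resource->CacheFuncGraph().
  }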
@@ -119,10 +114,8 @@ class Resource : public ResourceBase {
   bool is_load_{false};
   bool vm_loop_flag_{false};
   int64_t loop_size_{1};
-  bool enable_compile_cache_{false};
-  size_t compile_cache_id_{0};
-  std::string compile_cache_dep_files_hash_;
   LayoutMap layout_map_{};
+  CompileCacheManagerPtr compile_cache_manager_{nullptr};
 };
 
 using ResourcePtr = std::shared_ptr<pipeline::Resource>;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ class ResourceBase {
 
   void SetResult(const std::string &key, const Any &value) { results_[key] = value; }
 
-  Any GetResult(const std::string &key) {
+  Any GetResult(const std::string &key) const {
     auto iter = results_.find(key);
     if (iter == results_.end()) {
       MS_LOG(EXCEPTION) << "this key is not in resource list:" << key;
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@
 #include <utility>
 #include <algorithm>
 #include <functional>
+#include <fstream>
 
 #include "utils/hash_map.h"
 #include "ir/tensor.h"
@@ -94,7 +95,7 @@ class IrExporter {
 class IrExportBuilder {
  public:
  IrExportBuilder() : model_(std::make_shared<mind_ir::ModelProto>()) {}
-  ~IrExportBuilder() { google::protobuf::ShutdownProtobufLibrary(); }
+  ~IrExportBuilder() = default;
   std::string GetProtoString() const;
   void BuildModelInfo();
   bool BuildModel(const FuncGraphPtr &func_graph);
@@ -125,10 +126,8 @@ class IrExportBuilder {
   bool SetScalarToAttributeProtoForInt_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
   bool SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
   bool SetTensorToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
-  bool SetSequenceToAttributeProto(const ValueSequencePtr &value, mind_ir::AttributeProto *const attr_proto,
-                                   std::string *const seq_string);
-  bool SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
-                                  std::string *const seq_string);
+  bool SetSequenceToAttributeProto(const ValueSequencePtr &value, mind_ir::AttributeProto *const attr_proto);
+  bool SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto);
 
   mind_ir::TensorProto_DataType GetMindirDataType(TypeId type_id);
   mind_ir::TensorProto_DataType GetMindirDataBitsIntType(int bits);
@@ -866,15 +865,12 @@ bool IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A
   } else if (value->isa<Number>() || value->isa<TensorType>()) {
     return SetTypeToAttributeProto(value, attr_proto);
   } else if (value->isa<ValueSequence>()) {
-    ResetTupleIndex();
-    std::string seq_string = "scalar:";
-    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
-    if (!SetSequenceToAttributeProto(value->cast<ValueSequencePtr>(), attr_proto, &seq_string)) {
+    if (!SetSequenceToAttributeProto(value->cast<ValueSequencePtr>(), attr_proto)) {
       MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
       return false;
     }
-    attr_proto->set_ref_attr_name(seq_string);
-    MS_LOG(DEBUG) << "Attr string: " << seq_string;
+    attr_proto->set_ref_attr_name("Sequence");
+    MS_LOG(DEBUG) << "Attr string: " << value->type_name();
   } else if (value->isa<tensor::Tensor>()) {
     return SetTensorToAttributeProto(value, attr_proto);
   } else if (value->isa<None>()) {
@@ -991,7 +987,7 @@ bool IrExportBuilder::SetTypeToAttributeProto_irs(const ValuePtr &value, mind_ir
     mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors();
     tensor_proto->set_data_type(mind_ir::TensorProto_DataType_BOOL);
   } else if (value->isa<tensor::Tensor>()) {
-    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
+    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSORS);
     return SetTensorToAttributeProto(value, attr_proto);
   } else {
     MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name();
@@ -1018,9 +1014,6 @@ bool IrExportBuilder::SetScalarToAttributeProto_irs(const ValuePtr &value, mind_
   } else if (value->isa<FP64Imm>()) {
     attr_proto->set_type(mind_ir::AttributeProto_AttributeType_DOUBLE);
     attr_proto->add_doubles(GetValue<double>(value));
-  } else if (value->isa<tensor::Tensor>()) {
-    attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TENSOR);
-    return SetTensorToAttributeProto(value, attr_proto);
   } else {
     MS_LOG(ERROR) << "Unsupported scalar type: " << value->type_name();
     return false;
@ -1060,16 +1053,11 @@ bool IrExportBuilder::SetScalarToAttributeProtoForInt_irs(const ValuePtr &value,
|
|||
return true;
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr) {
|
||||
MS_LOG(ERROR) << "Value is nullptr";
|
||||
return false;
|
||||
}
|
||||
string value_name = "value" + std::to_string(GetTupleIndex());
|
||||
if (seq_string != nullptr) {
|
||||
*seq_string += value_name + ",";
|
||||
}
|
||||
if (value->isa<StringImm>() || value->isa<Scalar>()) {
|
||||
return SetScalarToAttributeProto_irs(value, attr_proto);
|
||||
}
|
||||
|
@ -1077,56 +1065,38 @@ bool IrExportBuilder::SetSeqElemToAttributeProto(const ValuePtr &value, mind_ir:
|
|||
}
|
||||
|
||||
bool IrExportBuilder::SetSequenceToAttributeProto(const ValueSequencePtr &value,
|
||||
mind_ir::AttributeProto *const attr_proto,
|
||||
std::string *const seq_string) {
|
||||
mind_ir::AttributeProto *const attr_proto) {
|
||||
if (value == nullptr || attr_proto == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "ValueSequencePtr or AttributeProto is null!";
|
||||
}
|
||||
if (value->isa<ValueTuple>() && seq_string != nullptr) {
|
||||
*seq_string += "Tuple[";
|
||||
const ValueTuplePtr &tuple_value = value->cast<ValueTuplePtr>();
|
||||
if (tuple_value->value().size() == 0) {
|
||||
*seq_string += "],";
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto tuple size is 0";
|
||||
return true;
|
||||
}
|
||||
for (const auto &item : tuple_value->value()) {
|
||||
if (item->isa<ValueTuple>()) {
|
||||
if (!SetSequenceToAttributeProto(item->cast<ValueTuplePtr>(), attr_proto, seq_string)) {
|
||||
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
|
||||
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
if (value->isa<ValueTuple>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
|
||||
} else if (value->isa<ValueList>()) {
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_LIST);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The sequance value should be ValueTuple or ValueList, but it is " << value->ToString();
|
||||
}
|
||||
auto value_sequence = value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_sequence);
|
||||
const auto &values = value_sequence->value();
|
||||
if (values.empty()) {
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto sequence size is 0";
|
||||
return true;
|
||||
}
|
||||
for (const auto &item : values) {
|
||||
mind_ir::AttributeProto *attr_values = attr_proto->add_values();
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (item->isa<ValueSequence>()) {
|
||||
if (!SetSequenceToAttributeProto(item->cast<ValueSequencePtr>(), attr_values)) {
|
||||
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!SetSeqElemToAttributeProto(item, attr_values)) {
|
||||
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
*seq_string += "],";
|
||||
} else if (value->isa<ValueList>() && seq_string != nullptr) {
|
||||
*seq_string += "List[";
|
||||
const ValueListPtr &list_value = value->cast<ValueListPtr>();
|
||||
if (list_value->value().size() == 0) {
|
||||
*seq_string += "],";
|
||||
MS_LOG(DEBUG) << "SetSequenceToAttributeProto list size is 0.";
|
||||
return true;
|
||||
}
|
||||
for (const auto &item : list_value->value()) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
if (item->isa<ValueList>()) {
|
||||
if (!SetSequenceToAttributeProto(item->cast<ValueListPtr>(), attr_proto, seq_string)) {
|
||||
MS_LOG(ERROR) << "Set sequence to AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!SetSeqElemToAttributeProto(item, attr_proto, seq_string)) {
|
||||
MS_LOG(ERROR) << "Set seq elem to AttributeProto failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
*seq_string += "],";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
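The old exporter flattened a ValueSequence into the bookkeeping string "scalar:Tuple[value0,value1,...]" carried in ref_attr_name; the new SetSequenceToAttributeProto instead marks the attribute as "Sequence" and nests one child AttributeProto per element via add_values(), so arbitrarily nested tuples and lists survive without string parsing. A minimal sketch of the resulting layout, assuming the generated mind_ir protobuf headers; BuildExampleSequenceAttr is a hypothetical illustration, and the scalar leaf payloads (filled by SetScalarToAttributeProto_irs in the real code) are elided:

// Hedged sketch, not MindSpore source: the proto tree built for the
// attribute value (1, (2, 3)).
void BuildExampleSequenceAttr(mind_ir::AttributeProto *attr) {
  attr->set_ref_attr_name("Sequence");
  attr->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
  (void)attr->add_values();                             // child 0: leaf for scalar 1
  mind_ir::AttributeProto *inner = attr->add_values();  // child 1: nested tuple (2, 3)
  inner->set_type(mind_ir::AttributeProto_AttributeType_TUPLE);
  (void)inner->add_values();  // leaf for scalar 2
  (void)inner->add_values();  // leaf for scalar 3
}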
@ -1145,9 +1115,35 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
  return ret;
}

ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &param_layout_fg) {
bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
                     const FuncGraphPtr &param_layout_fg) {
  auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
  auto result = exporter->GetDumpProto(func_graph, param_layout_fg);
  return result;
  auto proto = exporter->GetDumpProto(func_graph, param_layout_fg);
  if (proto == nullptr) {
    MS_LOG(ERROR) << "Get binary proto for graph " << func_graph->ToString() << " failed.";
    return false;
  }

  auto realpath = Common::CreatePrefixPath(file_path, true);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path of file " << file_path << " failed.";
    return false;
  }

  ChangeFileMode(realpath.value(), S_IWUSR);
  std::ofstream fout(realpath.value());
  if (!fout.is_open()) {
    MS_LOG(ERROR) << "Open the file '" << realpath.value() << "' failed!" << ErrnoToString(errno);
    return false;
  }

  if (!proto->SerializeToOstream(&fout)) {
    MS_LOG(ERROR) << "Failed to write the mindir proto to file " << realpath.value();
    fout.close();
    return false;
  }
  fout.close();
  ChangeFileMode(realpath.value(), S_IRUSR);
  return true;
}
}  // namespace mindspore
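With DumpBinaryProto, callers that previously fetched the ModelProtoPtr from GetBinaryProto and hand-rolled the ofstream, serialization, and permission handling can delegate all of it to one call. A hedged usage sketch; SaveCompileCache and the cache path are illustrative only (the path mirrors the one checked by the ps-mode test below):

// Hedged usage sketch, not MindSpore source.
bool SaveCompileCache(const FuncGraphPtr &func_graph) {
  const std::string cache_file = "rank_0/graph_cache/compile_cache_0.mindir";  // hypothetical path
  if (!DumpBinaryProto(func_graph, cache_file)) {
    MS_LOG(ERROR) << "Failed to cache the compiled graph to " << cache_file;
    return false;
  }
  return true;
}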
@ -1,5 +1,5 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@ -49,12 +49,13 @@ enum ParseForm : int {
  FORM_PARSE_TENSOR = 2,
  FORM_PARSE_NONE = 3,
  FORM_PARSE_MONAD = 4,
  FORM_PARSE_UNDEFINE = 5,
  FORM_PARSE_SEQUENCE = 5,
  FORM_PARSE_UNDEFINE = 6,
};

static std::map<std::string, ParseForm> kParseTypeSwitchMap{
  {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR},
  {"none", FORM_PARSE_NONE}, {"Monad", FORM_PARSE_MONAD}, {"", FORM_PARSE_UNDEFINE}};
  {"none", FORM_PARSE_NONE}, {"Monad", FORM_PARSE_MONAD}, {"Sequence", FORM_PARSE_SEQUENCE}};

static mindspore::HashMap<int, TypeId> kDefaultValueSwitchMap{
  {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool},
@ -226,17 +227,13 @@ ValuePtr ParseAttrInSingleScalar_double_double(const mind_ir::AttributeProto &at
  return MakeValue<double>(value);
}

string GetTypeString(const std::string &ref_attr_name, size_t *pos) {
  if ((*pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    return ref_attr_name.substr(*pos, string("scalar:").length() - 1);
  } else if ((*pos = ref_attr_name.find("type:")) != std::string::npos) {
    return ref_attr_name.substr(*pos, string("type:").length() - 1);
  } else if ((*pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    return ref_attr_name.substr(*pos, string("tensor:").length() - 1);
  } else if (ref_attr_name == "none") {
    return ref_attr_name;
ParseForm GetParseFormType(const std::string &ref_attr_name) {
  for (const auto &iter : kParseTypeSwitchMap) {
    if (ref_attr_name.find(iter.first) == 0) {
      return iter.second;
    }
  }
  return "";
  return FORM_PARSE_UNDEFINE;
}
}  // namespace
tensor::TensorPtr MSANFModelParser::GenerateTensorPtrFromTensorProto(const mind_ir::TensorProto &attr_tensor,
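GetParseFormType dispatches on a prefix match (find(iter.first) == 0), so the legacy composite names such as "scalar:Tuple[...]" still resolve to FORM_PARSE_SCALAR while the new "Sequence" attributes resolve to FORM_PARSE_SEQUENCE. A self-contained sketch of that dispatch, restricted to three of the map entries shown above:

#include <cassert>
#include <map>
#include <string>

enum ParseForm : int { FORM_PARSE_TYPE, FORM_PARSE_SCALAR, FORM_PARSE_SEQUENCE, FORM_PARSE_UNDEFINE };

static const std::map<std::string, ParseForm> kParseTypeSwitchMap{
  {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"Sequence", FORM_PARSE_SEQUENCE}};

ParseForm GetParseFormType(const std::string &ref_attr_name) {
  for (const auto &iter : kParseTypeSwitchMap) {
    if (ref_attr_name.find(iter.first) == 0) {  // the key must be a prefix of the name
      return iter.second;
    }
  }
  return FORM_PARSE_UNDEFINE;
}

int main() {
  assert(GetParseFormType("scalar:Tuple[value0,value1,]") == FORM_PARSE_SCALAR);  // legacy form
  assert(GetParseFormType("Sequence") == FORM_PARSE_SEQUENCE);                    // new form
  assert(GetParseFormType("unknown") == FORM_PARSE_UNDEFINE);
  return 0;
}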
@ -600,13 +597,13 @@ ValuePtr MSANFModelParser::ParseAttrInScalarForm(const mind_ir::AttributeProto &
      const int attr_tensor_type = attr_proto.tensors(index).data_type();
      if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) {
        MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type;
        return {};
        return nullptr;
      }
      return TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]);
    }
    default:
      MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_type;
      return {};
      return nullptr;
  }
}
@ -694,11 +691,9 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  std::size_t pos(0);
  string type = GetTypeString(ref_attr_name, &pos);

  ParseForm type = GetParseFormType(ref_attr_name);
  mindspore::HashMap<std::string, ValuePtr> multi_value_map;
  switch (kParseTypeSwitchMap[type]) {
  switch (type) {
    case FORM_PARSE_TYPE: {
      ObtainCNodeAttrInTypeForm(prim, attr_proto);
      break;

@ -740,13 +735,22 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const mind
      prim->AddAttr(attr_name, kNone);
      break;
    }
    case FORM_PARSE_SEQUENCE: {
      auto sequence_value = ObtainValueInSequenceForm(attr_proto);
      if (sequence_value == nullptr) {
        MS_LOG(ERROR) << "Failed to get sequence value for " << attr_name;
        return false;
      }
      prim->AddAttr(attr_name, sequence_value);
      break;
    }
    default:
      MS_LOG(ERROR) << "parse attr type doesn't support the ref_attr_name: " << ref_attr_name;
      return false;
  }

  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
    if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
  if (type == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
    if (ref_attr_name.find("Tuple") != std::string::npos) {
      auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
      prim->AddAttr(attr_name, value_tuple_ptr);
    } else {
@ -843,6 +847,67 @@ bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_
  return true;
}

ValuePtr MSANFModelParser::ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto) {
  std::vector<ValuePtr> vec;
  for (int i = 0; i < attr_proto.values_size(); ++i) {
    mind_ir::AttributeProto elem_attr_proto = attr_proto.values(i);
    switch (elem_attr_proto.type()) {
      case mind_ir::AttributeProto_AttributeType_TENSORS: {
        mind_ir::TensorProto tensor_proto = elem_attr_proto.tensors(0);
        if (tensor_proto.has_raw_data()) {
          // For real tensor.
          tensor::TensorPtr tensor_info = GenerateTensorPtrFromTensorProto(tensor_proto);
          if (tensor_info == nullptr) {
            MS_LOG(ERROR) << "Failed to get the tensor for ValueNode.";
            return nullptr;
          }
          (void)vec.emplace_back(tensor_info);
        } else {
          // For data type.
          const int attr_tensor_type = tensor_proto.data_type();
          auto iter = kDefaultValueSwitchMap.find(attr_tensor_type);
          if (iter == kDefaultValueSwitchMap.end()) {
            MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type;
            return nullptr;
          }
          (void)vec.emplace_back(TypeIdToType(iter->second));
        }
        break;
      }
      case mind_ir::AttributeProto_AttributeType_TUPLE:
      case mind_ir::AttributeProto_AttributeType_LIST: {
        auto sequence_value = ObtainValueInSequenceForm(elem_attr_proto);
        if (sequence_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the sequence value";
          return nullptr;
        }
        (void)vec.emplace_back(sequence_value);
        break;
      }
      default: {
        // For string and scalar.
        auto scalar_value = ParseAttrInScalarForm(elem_attr_proto, 0);
        if (scalar_value == nullptr) {
          MS_LOG(ERROR) << "Failed to get the scalar for ValueNode.";
          return nullptr;
        }
        (void)vec.emplace_back(scalar_value);
      }
    }
  }
  auto type = attr_proto.type();
  ValuePtr value_sequence;
  if (type == mind_ir::AttributeProto_AttributeType_TUPLE) {
    value_sequence = std::make_shared<ValueTuple>(vec);
  } else if (type == mind_ir::AttributeProto_AttributeType_LIST) {
    value_sequence = std::make_shared<ValueList>(vec);
  } else {
    MS_LOG(EXCEPTION) << "The attribute type should be tuple or list, but it is " << type;
  }

  return value_sequence;
}

bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name,
                                                const mind_ir::AttributeProto &attr_proto) {
  if (!attr_proto.has_ref_attr_name()) {
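ObtainValueInSequenceForm is the import-side mirror of the exporter change: it walks the nested values entries, recursing on TUPLE/LIST children and deferring leaves to GenerateTensorPtrFromTensorProto or ParseAttrInScalarForm. A hedged illustration of the round trip for the attribute value ((1, 2), "x"); the leaf payload fields are elided because they depend on the scalar encoding not shown in this diff:

// Hedged illustration, not MindSpore source: the tree consumed by
// ObtainValueInSequenceForm, which rebuilds ValueTuple{ValueTuple{1, 2}, "x"}.
//
//   AttributeProto         ref_attr_name: "Sequence", type: TUPLE
//   +- values[0]           type: TUPLE          -> recursive call
//   |  +- values[0]        scalar leaf for 1    -> ParseAttrInScalarForm
//   |  +- values[1]        scalar leaf for 2    -> ParseAttrInScalarForm
//   +- values[1]           scalar leaf for "x"  -> ParseAttrInScalarForm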
@ -850,23 +915,10 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_na
    return false;
  }
  const std::string &ref_attr_name = attr_proto.ref_attr_name();
  string type = "";
  std::size_t pos;
  if ((pos = ref_attr_name.find("scalar:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("scalar:").length() - 1);
  } else if ((pos = ref_attr_name.find("type:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("type:").length() - 1);
  } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("tensor:").length() - 1);
  } else if ((pos = ref_attr_name.find("Monad:")) != std::string::npos) {
    type = ref_attr_name.substr(pos, string("Monad:").length() - 1);
  } else if (ref_attr_name == "none") {
    type = ref_attr_name;
  }

  ParseForm type = GetParseFormType(ref_attr_name);
  ValueNodePtr new_value_node;
  mindspore::HashMap<std::string, ValuePtr> multi_value_map;
  switch (kParseTypeSwitchMap[type]) {
  switch (type) {
    case FORM_PARSE_TYPE: {
      ObtainValueNodeInTypeForm(value_node_name, attr_proto.tensors(0));
      break;

@ -879,6 +931,7 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_na
      anfnode_build_map_[value_node_name] = new_value_node;
      break;
    }
    // Compatible with old versions.
    if (ref_attr_name.find("Tuple[]") != std::string::npos) {
      MS_LOG(INFO) << "Build Tuple() ValueNode for primitive.";
      ValuePtr res = MakeValue(std::vector<ValuePtr>{});

@ -907,13 +960,24 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_na
      ObtainValueNodeInMonadForm(value_node_name, attr_proto);
      break;
    }
    case FORM_PARSE_SEQUENCE: {
      auto sequence_value = ObtainValueInSequenceForm(attr_proto);
      if (sequence_value == nullptr) {
        MS_LOG(ERROR) << "Failed to get sequence value for " << value_node_name;
        return false;
      }
      new_value_node = NewValueNode(sequence_value);
      new_value_node->set_abstract(sequence_value->ToAbstract());
      anfnode_build_map_[value_node_name] = new_value_node;
      break;
    }
    default:
      MS_LOG(ERROR) << "parse attr type doesn't support the ref_attr_name: " << ref_attr_name;
      return false;
  }

  if (kParseTypeSwitchMap[type] == FORM_PARSE_SCALAR && multi_value_map.size() != 0) {
    if ((pos = ref_attr_name.find("Tuple")) != std::string::npos) {
  // Compatible with old versions.
  if (type == FORM_PARSE_SCALAR && !multi_value_map.empty()) {
    if (ref_attr_name.find("Tuple") != std::string::npos) {
      auto value_tuple_ptr = ParserScalarAttrValue<ValueTuple>(ref_attr_name, multi_value_map);
      new_value_node = NewValueNode(value_tuple_ptr);
      new_value_node->set_abstract(value_tuple_ptr->ToAbstract());
@ -1,5 +1,5 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@ -86,6 +86,7 @@ class MSANFModelParser {
  bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor);
  bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
  bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto);
  ValuePtr ObtainValueInSequenceForm(const mind_ir::AttributeProto &attr_proto);
  bool little_endian() { return little_endian_; }
  mindspore::HashMap<std::string, abstract::AbstractBasePtr> GetAbstractForNode(
    const mind_ir::AttributeProto &attr_proto);
@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@ -69,7 +69,7 @@ class MindIRLoader {
  void set_weights_value_map(const std::map<string, ValuePtr> &weights_value_map) {
    weights_value_map_ = weights_value_map;
  }
  const LayoutMap &get_layout_map() { return layout_map_; }
  const LayoutMap &layout_map() { return layout_map_; }
  FuncGraphPtr LoadMindIR(const void *buffer, const size_t &size);
  FuncGraphPtr LoadMindIR(const std::string &file_name);
  std::vector<FuncGraphPtr> LoadMindIRs(const std::vector<std::string> file_names);
@ -0,0 +1,136 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import sys

import mindspore.context as context
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from mindspore.dataset.vision import Inter
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import LossMonitor
from mindspore.common.initializer import TruncatedNormal

DATASET_PATH = "/home/workspace/mindspore_dataset/mnist"
context.set_context(mode=context.GRAPH_MODE, enable_compile_cache=True, compile_cache_path=sys.argv[1])
context.set_ps_context(enable_ps=True)


def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """weight initial for conv layer"""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")


def fc_with_initialize(input_channels, out_channels):
    """weight initial for fc layer"""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)


def weight_variable():
    """weight initial"""
    return TruncatedNormal(0.02)


class LeNet5(nn.Cell):
    def __init__(self, num_class=10, channel=1):
        super(LeNet5, self).__init__()
        self.num_class = num_class
        self.conv1 = conv(channel, 6, 5)
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, self.num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """
    create dataset for train or test
    """
    # define dataset
    mnist_ds = ds.MnistDataset(data_path)

    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # Bilinear mode
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # apply map operations on images
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # apply DatasetOps
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)  # 10000 as in LeNet train script
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


if __name__ == "__main__":
    network = LeNet5(10)
    network.set_param_ps()
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    ds_train = create_dataset(os.path.join(DATASET_PATH, "train"), 32, 1)
    model.train(1, ds_train, callbacks=[LossMonitor()], dataset_sink_mode=False)

    ds_eval = create_dataset(os.path.join(DATASET_PATH, "test"), 32, 1)
    acc = model.eval(ds_eval, dataset_sink_mode=False)

    print("Accuracy:", acc['Accuracy'])
    assert acc['Accuracy'] > 0.83
@ -1,4 +1,4 @@
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@ -53,8 +53,8 @@ def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_
    array_shape_first = np.array([int(x) for x in shape_first])

    # Second run with compile cache
    cmd_second = cmd_first = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_second +\
                 " 2>&1"
    cmd_second = f"GLOG_v=2 python " + file_name + " '" + cache_path + "' > " + log_file_name_second + \
                 " 2>&1"
    subprocess.check_output(cmd_second, shell=True)
    assert os.path.exists(log_file_name_second)
    with open(log_file_name_second, "r") as f_second:
@ -108,6 +108,89 @@ def run_twice_with_different_networks(file_name_first, file_name_second, cache_p
    os.remove(log_file_name_second)
    shutil.rmtree(cache_path)


def check_log(role, log_name, str_to_check):
    assert os.path.exists(role + "/" + log_name)
    with open(role + "/" + log_name, "r") as f:
        data = f.read()
    assert str_to_check in data


def start_ps_subprocess(script_path, cache_path, str_to_check, log_name):
    cwd = os.getcwd()
    # start sched first time.
    os.environ['MS_ROLE'] = 'MS_SCHED'
    cmd_first = f"cd " + cwd + "/sched && GLOG_v=2 python ../" + script_path + " ../" + cache_path + " > " \
                + log_name + " 2>&1 &"
    subprocess.run(cmd_first, shell=True)
    # start server first time.
    os.environ['MS_ROLE'] = 'MS_PSERVER'
    cmd_first = f"cd " + cwd + "/server && GLOG_v=2 python ../" + script_path + " ../" + cache_path + " > " \
                + log_name + " 2>&1 &"
    subprocess.run(cmd_first, shell=True)
    # start worker first time.
    os.environ['MS_ROLE'] = 'MS_WORKER'
    cmd_first = f"cd " + cwd + "/worker && GLOG_v=2 python ../" + script_path + " ../" + cache_path + " > " \
                + log_name + " 2>&1"
    subprocess.run(cmd_first, shell=True, check=True)
    os.chdir(cwd)
    check_log("sched", log_name, str_to_check)
    check_log("server", log_name, str_to_check)
    check_log("worker", log_name, str_to_check)


def clear_and_make_run_dir(dir_path):
    shutil.rmtree(dir_path, ignore_errors=True)
    assert not os.path.exists(dir_path)
    os.mkdir(dir_path)
    assert os.path.exists(dir_path)


def check_compile_cache_files(cache_path, role):
    assert os.path.exists(cache_path)
    assert os.path.exists(cache_path + "/rank_0/graph_cache/" + role + "compile_cache_0.mindir")
    assert os.path.exists(cache_path + "/rank_0/graph_cache/" + role + "compile_dependency.hash")


def run_lenet_ps_twice(file_name, cache_path, log_file_name_first, log_file_name_second):
    # Clear compile cache folder and log files
    shutil.rmtree(cache_path, ignore_errors=True)
    assert not os.path.exists(cache_path)
    clear_and_make_run_dir("sched")
    clear_and_make_run_dir("server")
    clear_and_make_run_dir("worker")
    # Set envs
    os.environ['MS_SCHED_HOST'] = '127.0.0.1'
    os.environ['MS_SCHED_PORT'] = '8182'
    os.environ['MS_SCHED_NUM'] = '1'
    os.environ['MS_SERVER_NUM'] = '1'
    os.environ['MS_WORKER_NUM'] = '1'
    # First run
    first_str_to_check = "Check the consistency of dependency files hash failed. Execute all the compilation actions."
    start_ps_subprocess(file_name, cache_path, first_str_to_check, log_file_name_first)
    assert os.path.exists(cache_path)
    check_compile_cache_files(cache_path, "")
    check_compile_cache_files(cache_path, "pserver_")
    check_compile_cache_files(cache_path, "pscheduler_")
    # Second run
    os.environ['MS_SCHED_PORT'] = '8183'
    second_str_to_check = "Use the compilation cache and execute the backend actions only. Be aware of correctness" \
                          " risks."
    start_ps_subprocess(file_name, cache_path, second_str_to_check, log_file_name_second)

    # Clear
    del os.environ['MS_SCHED_HOST']
    del os.environ['MS_SCHED_PORT']
    del os.environ['MS_ROLE']
    del os.environ['MS_SCHED_NUM']
    del os.environ['MS_SERVER_NUM']
    del os.environ['MS_WORKER_NUM']
    shutil.rmtree("sched", ignore_errors=True)
    shutil.rmtree("server", ignore_errors=True)
    shutil.rmtree("worker", ignore_errors=True)
    shutil.rmtree(cache_path, ignore_errors=True)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@ -160,3 +243,37 @@ def test_compile_cache_auto_detect():
    """
    run_twice_with_different_networks("run_lenet.py", "run_network_with_weights.py", "./lenet_auto_detect",
                                      "auto_detect_first.txt", "auto_detect_second.txt")


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_compile_cache_lenet_change_dir():
    """
    Feature: Compile cache.
    Description: Test whether the regular compile cache function can run successfully when changing
    the current work directory.
    Expectation: success.
    """
    cwd = os.getcwd()
    new_path = cwd + '/tmp'
    shutil.rmtree(new_path, ignore_errors=True)
    os.mkdir(new_path)
    os.chdir(new_path)
    run_twice_with_same_network("../run_lenet.py", "../lenet_change_dir", "../lenet_change_dir_first.txt",
                                "../lenet_change_dir_second.txt")
    shutil.rmtree(new_path, ignore_errors=True)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_compile_cache_lenet_ps():
    """
    Feature: Compile cache.
    Description: Test whether the regular compile cache function can run successfully with lenet in ps mode.
    Expectation: success.
    """
    run_lenet_ps_twice("run_lenet_ps.py", "./lenet_ps", "lenet_ps_first.txt", "lenet_ps_second.txt")
@ -1,5 +1,5 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@ -24,12 +24,10 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return "";

std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }

std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
  return "";
}
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }

ModelProtoPtr GetBinaryProto(const FuncGraphPtr &func_graph, const FuncGraphPtr &layout_fg) {
  ModelProtoPtr empty_model;
  return empty_model;
bool DumpBinaryProto(const FuncGraphPtr &func_graph, const std::string &file_path,
                     const FuncGraphPtr &param_layout_fg) {
  return true;
}
}  // namespace mindspore