forked from mindspore-Ecosystem/mindspore
!17173 Add frontend compile cache
From: @ginfung Reviewed-by: @hwhewei,@zh_qh Signed-off-by: @zh_qh
Commit aa2093fb7d
@@ -19,6 +19,7 @@
#include <string>
#include "ir/func_graph.h"
#include "proto/mind_ir.pb.h"

namespace mindspore {
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
@@ -27,6 +28,8 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);

std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);

mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph);

void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
} // namespace mindspore
@@ -836,6 +836,15 @@ std::vector<ActionItem> VmPipeline() {
  return actions;
}

std::vector<ActionItem> BackendPipeline() {
  std::vector<ActionItem> actions;
  // compile the ANF graph
  actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
  // to execute the graph
  actions.emplace_back(std::make_pair("execute", ExecuteAction));
  return actions;
}

#if (ENABLE_CPU && !_WIN32)
std::vector<ActionItem> ServerPipeline() {
  auto actions = CommonPipeline();
@@ -49,6 +49,7 @@ bool StartServerAction(const ResourcePtr &res);

std::vector<ActionItem> GePipeline();
std::vector<ActionItem> VmPipeline();
std::vector<ActionItem> BackendPipeline();
std::vector<ActionItem> PServerPipeline();
std::vector<ActionItem> ServerPipeline();
std::vector<ActionItem> PSchedulerPipeline();
@@ -101,6 +101,8 @@ std::unordered_map<abstract::AbstractBasePtrList, int64_t, abstract::AbstractBas
  g_args_cache;

namespace {
constexpr char kCompileCacheFilePath[] = "compile_cache.mindir";

std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
  std::ostringstream oss;
  int spaces = 2;
@@ -155,21 +157,71 @@ std::string GetCompileExceptionInfo() {
  return oss.str();
}

-void SetGpuLoopSink(const ResourcePtr &resource_) {
-  auto func_graph = resource_->func_graph();
+void SetGpuLoopSink(const ResourcePtr &resource) {
+  auto func_graph = resource->func_graph();
  if (func_graph != nullptr && func_graph->manager() != nullptr) {
    auto manager = func_graph->manager();
    size_t graph_nums = manager->func_graphs().size();
    int64_t sinksize = ConfigManager::GetInstance().iter_num();
    if (graph_nums == 1) {
-      resource_->set_gpu_loopsink(true, sinksize);
+      resource->set_gpu_loopsink(true, sinksize);
    } else {
-      resource_->set_gpu_loopsink(false, sinksize);
+      resource->set_gpu_loopsink(false, sinksize);
    }
-    MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to "
+    MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource->gpu_loopsink_flag() << ", set loopsink size to "
                 << sinksize;
  }
}

void GetCachedFuncGraph(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto realpath = Common::GetRealPath(kCompileCacheFilePath);
  if (!realpath.has_value()) {
    MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
  }
  std::ifstream f(realpath.value());
  bool cache_file_existed = f.good();
  f.close();
  if (!cache_file_existed) {
    MS_LOG(WARNING) << "The compilation cache file '" << realpath.value()
                    << "' does not exist. Execute all the compilation actions.";
    return;
  }
  MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
  FuncGraphPtr fg = LoadMindIR(realpath.value());
  if (fg == nullptr) {
    MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
  }
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    auto res_mng = resource->manager();
    MS_EXCEPTION_IF_NULL(res_mng);
    res_mng->AddFuncGraph(fg);
    fg->set_manager(res_mng);
  }
  resource->set_func_graph(fg);
}

void CacheFuncGraph(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto realpath = Common::GetRealPath(kCompileCacheFilePath);
  if (!realpath.has_value()) {
    MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
  }

  ChangeFileMode(realpath.value(), S_IRWXU);
  std::ofstream fout(realpath.value());
  if (!fout.is_open()) {
    MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!";
  }
  FuncGraphPtr fg = resource->func_graph();
  mind_ir::ModelProto fg_model = GetBinaryProto(fg);
  if (!fg_model.SerializeToOstream(&fout)) {
    MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
  }
  fout.close();
  ChangeFileMode(realpath.value(), S_IRUSR);
}
} // namespace

py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
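GetCachedFuncGraph and CacheFuncGraph both resolve kCompileCacheFilePath, so the whole cache lives in a single compile_cache.mindir file in the current working directory. A minimal user-side sketch of how that file behaves across runs (illustrative only, not part of this change):

import os

CACHE_FILE = "compile_cache.mindir"  # matches kCompileCacheFilePath above

if not os.path.exists(CACHE_FILE):
    # First run: no cache yet, so all frontend compilation actions execute and,
    # with save_compile_cache=True, CacheFuncGraph writes the file at the
    # "validate" step.
    print("no compile cache, full frontend compilation will run")
else:
    # Later runs with load_compile_cache=True pick the file up in
    # GetCachedFuncGraph and only the backend actions execute.
    print("compile cache found:", os.path.getsize(CACHE_FILE), "bytes")
    # Deleting the file (or changing the network) forces a full recompilation;
    # the commit does not add any automatic staleness check.
    # os.remove(CACHE_FILE)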
@@ -556,6 +608,11 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
    // Connect session to debugger
    backend_ptr->SetDebugger();
    resource->results()[kBackend] = backend_ptr;
    // If the 'use_frontend_compile_cache' context has been set true and the cache is read successfully,
    // do the backend actions only.
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE) && resource->func_graph() != nullptr) {
      return BackendPipeline();
    }
    return VmPipeline();
  }
  return GePipeline();
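The branch above is the core of the feature: when MS_CTX_LOAD_COMPILE_CACHE is set and GetCachedFuncGraph has already attached a cached graph to the resource, GetPipeline returns BackendPipeline() (only task_emit and execute) instead of the full VmPipeline(). A small Python mirror of that decision, purely for illustration:

def choose_pipeline(load_compile_cache, cached_graph_loaded):
    """Mirror of the C++ branch in GetPipeline, for illustration only."""
    if load_compile_cache and cached_graph_loaded:
        # BackendPipeline(): the frontend passes are skipped entirely.
        return ["task_emit", "execute"]
    # VmPipeline(): the full frontend action list followed by the backend actions.
    return ["(all frontend actions)", "task_emit", "execute"]

# A second run that successfully loaded compile_cache.mindir:
print(choose_pipeline(load_compile_cache=True, cached_graph_loaded=True))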
@@ -583,6 +640,10 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
  MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
  ResourcePtr resource = std::make_shared<Resource>(obj);

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE)) {
    GetCachedFuncGraph(resource);
  }

  auto p_actions = GetPipeline(resource, phase_s, use_vm);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(p_actions, phase_s));
@@ -711,6 +772,10 @@ void Pipeline::Run() {
      };
      if (action.first == "task_emit") {
        SetGpuLoopSink(resource_);
      } else if (action.first == "validate") {
        if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_COMPILE_CACHE)) {
          CacheFuncGraph(resource_);
        }
      }
      if (!result) {
        MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;
@@ -98,7 +98,9 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
  .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
  .value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
  .value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
- .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
+ .value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
+ .value("save_compile_cache", MsCtxParam::MS_CTX_SAVE_COMPILE_CACHE)
+ .value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE);
  (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
  .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
  .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")
@@ -71,6 +71,7 @@ class IrExporter {
  explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
  virtual ~IrExporter() = default;
  std::string GetDumpString(const FuncGraphPtr &func_graph);
  mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph);

 private:
  IrExportBuilderPtr builder_;
@@ -83,6 +84,7 @@ class IrExportBuilder {
  std::string GetProtoString(const FuncGraphPtr &func_graph);
  void BuildModelInfo();
  void BuildModel(const FuncGraphPtr &func_graph);
  mind_ir::ModelProto Model() { return model_; }

 private:
  void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
@@ -146,6 +148,20 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
  return builder_->GetProtoString(func_graph);
}

mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
  if ((builder_ == nullptr) || (func_graph == nullptr)) {
    MS_LOG(EXCEPTION) << "Input params is null.";
  }

  // Export model info
  builder_->BuildModelInfo();

  // Export model and return string
  builder_->BuildModel(func_graph);

  return builder_->Model();
}

std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
  MS_LOG(DEBUG) << "BuildModel complete!";
  return model_.SerializeAsString();
@@ -729,4 +745,9 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
  }
  return exporter->GetDumpString(func_graph);
}

mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph) {
  auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
  return exporter->GetDumpProto(func_graph);
}
} // namespace mindspore
@@ -513,7 +513,8 @@ def _check_target_specific_cfgs(device, arg_key):
             save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
             enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
             enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
-            enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str)
+            enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
+            save_compile_cache=bool, load_compile_cache=bool)
def set_context(**kwargs):
    """
    Set context for running environment.
@@ -549,6 +550,8 @@ def set_context(**kwargs):
    save_graphs_path            auto_tune_mode
    env_config_path             graph_kernel_flags
    grad_for_scalar
    save_compile_cache
    load_compile_cache
    =========================== =========================== =================

    Args:
@@ -633,6 +636,11 @@ def set_context(**kwargs):
            - rl_tune: Reinforcement Learning tune.
            - ga_tune: Genetic Algorithm tune.
        grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
        save_compile_cache (bool): Experimental. Whether to cache the graph compiled by frontend. Default: False.
        load_compile_cache (bool): Experimental. Whether to use the cache of the graph compiled by frontend.
            When it is true, the graph compilation will skip the frontend compilation process. This means
            you should make sure the network has not been changed since the last execution, because
            automatic checking for such changes is not supported yet. Default: False.

    Raises:
        ValueError: If input key is not an attribute in context.
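A usage sketch for the two new options documented above (the network, shapes, and values are illustrative; both options default to False):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

# Experimental frontend compile cache: save it on the first run and try to
# load it on later runs of the same, unchanged network.
context.set_context(mode=context.GRAPH_MODE,
                    save_compile_cache=True,
                    load_compile_cache=True)

net = nn.Dense(16, 8)
x = Tensor(np.ones((4, 16)).astype(np.float32))

# The first call compiles the graph and writes compile_cache.mindir to the
# current working directory; subsequent runs find the file and skip the
# frontend compilation actions.
out = net(x)
print(out.shape)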
@@ -139,7 +139,11 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
  std::fstream input_graph(abs_path_buff, std::ios::in | std::ios::binary);
  mind_ir::ModelProto origin_model;

- if (!input_graph || !origin_model.ParseFromIstream(&input_graph)) {
+ if (!input_graph) {
+   MS_LOG(ERROR) << "Failed to open file: " << file_name;
+   return nullptr;
+ }
+ if (!origin_model.ParseFromIstream(&input_graph)) {
    MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
    return nullptr;
  }
@@ -84,6 +84,8 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
  set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
  set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
  set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
  set_param<bool>(MS_CTX_SAVE_COMPILE_CACHE, false);
  set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);

  backend_policy_ = policy_map_[policy];
}
@@ -86,6 +86,8 @@ enum MsCtxParam : unsigned {
  MS_CTX_ENABLE_PARALLEL_SPLIT,
  MS_CTX_ENABLE_INFER_OPT,
  MS_CTX_GRAD_FOR_SCALAR,
  MS_CTX_SAVE_COMPILE_CACHE,
  MS_CTX_LOAD_COMPILE_CACHE,
  MS_CTX_TYPE_BOOL_END,

  // parameter of type int
@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_compile_cache=True, load_compile_cache=True)


class LeNet(nn.Cell):
    def __init__(self):
        super(LeNet, self).__init__()
        self.relu = P.ReLU()
        self.batch_size = 32

        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()
        self.fc1 = nn.Dense(400, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, 10)

    def construct(self, input_x):
        output = self.conv1(input_x)
        output = self.relu(output)
        output = self.pool(output)
        output = self.conv2(output)
        output = self.relu(output)
        output = self.pool(output)
        output = self.reshape(output, (self.batch_size, -1))
        output = self.fc1(output)
        output = self.relu(output)
        output = self.fc2(output)
        output = self.relu(output)
        output = self.fc3(output)
        return output


def train(net, data, label):
    learning_rate = 0.01
    momentum = 0.9

    optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
    criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)  # optimizer
    train_network.set_train()
    res = train_network(data, label)
    print("+++++++++Loss+++++++++++++")
    print(res)
    print("+++++++++++++++++++++++++++")
    diff = res.asnumpy() - 2.302585
    assert np.all(diff < 1.e-6)


@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_lenet():
    path = "compile_cache.mindir"
    if os.path.exists(path):
        os.remove(path)
    assert not os.path.exists(path)
    data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label = Tensor(np.ones([32]).astype(np.int32))
    net = LeNet()
    train(net, data, label)
    assert os.path.exists(path)

    data1 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
    label1 = Tensor(np.ones([32]).astype(np.int32))
    net1 = LeNet()
    train(net1, data1, label1)
@@ -14,6 +14,7 @@
 * limitations under the License.
 */
#include "debug/dump_proto.h"
#include "proto/mind_ir.pb.h"

namespace mindspore {
@@ -24,4 +25,9 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return "";
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }

std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }

mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph) {
  mind_ir::ModelProto empty_model;
  return empty_model;
}
} // namespace mindspore