!17173 Add frontend compile cache

From: @ginfung
Reviewed-by: @hwhewei, @zh_qh
Signed-off-by: @zh_qh
Committed by mindspore-ci-bot (via Gitee) on 2021-05-31 10:26:16 +08:00
commit aa2093fb7d
12 changed files with 225 additions and 8 deletions

View File

@@ -19,6 +19,7 @@
#include <string>
#include "ir/func_graph.h"
#include "proto/mind_ir.pb.h"
namespace mindspore {
std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph);
@@ -27,6 +28,8 @@ std::string GetOnnxProtoString(const FuncGraphPtr &func_graph);
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph);
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph);
void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix);
} // namespace mindspore

View File

@@ -836,6 +836,15 @@ std::vector<ActionItem> VmPipeline() {
return actions;
}
std::vector<ActionItem> BackendPipeline() {
std::vector<ActionItem> actions;
// compile the ANF graph
actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
// to execute the graph
actions.emplace_back(std::make_pair("execute", ExecuteAction));
return actions;
}
#if (ENABLE_CPU && !_WIN32)
std::vector<ActionItem> ServerPipeline() {
auto actions = CommonPipeline();

View File

@@ -49,6 +49,7 @@ bool StartServerAction(const ResourcePtr &res);
std::vector<ActionItem> GePipeline();
std::vector<ActionItem> VmPipeline();
std::vector<ActionItem> BackendPipeline();
std::vector<ActionItem> PServerPipeline();
std::vector<ActionItem> ServerPipeline();
std::vector<ActionItem> PSchedulerPipeline();

View File

@@ -101,6 +101,8 @@ std::unordered_map<abstract::AbstractBasePtrList, int64_t, abstract::AbstractBas
g_args_cache;
namespace {
constexpr char kCompileCacheFilePath[] = "compile_cache.mindir";
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
std::ostringstream oss;
int spaces = 2;
@@ -155,21 +157,71 @@ std::string GetCompileExceptionInfo() {
return oss.str();
}
-void SetGpuLoopSink(const ResourcePtr &resource_) {
-auto func_graph = resource_->func_graph();
+void SetGpuLoopSink(const ResourcePtr &resource) {
+auto func_graph = resource->func_graph();
if (func_graph != nullptr && func_graph->manager() != nullptr) {
auto manager = func_graph->manager();
size_t graph_nums = manager->func_graphs().size();
int64_t sinksize = ConfigManager::GetInstance().iter_num();
if (graph_nums == 1) {
-resource_->set_gpu_loopsink(true, sinksize);
+resource->set_gpu_loopsink(true, sinksize);
} else {
-resource_->set_gpu_loopsink(false, sinksize);
+resource->set_gpu_loopsink(false, sinksize);
}
-MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource_->gpu_loopsink_flag() << ", set loopsink size to "
+MS_LOG(INFO) << "Change gpu_loopsink_flag_ to " << resource->gpu_loopsink_flag() << ", set loopsink size to "
<< sinksize;
}
}
void GetCachedFuncGraph(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto realpath = Common::GetRealPath(kCompileCacheFilePath);
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
}
std::ifstream f(realpath.value());
bool cache_file_existed = f.good();
f.close();
if (!cache_file_existed) {
MS_LOG(WARNING) << "The compilation cache file '" << realpath.value()
<< "' dose not exist. Execute all the compilation actions.";
return;
}
MS_LOG(INFO) << "Use the compilation cache \"" << realpath.value() << "\" and execute the backend actions only.";
FuncGraphPtr fg = LoadMindIR(realpath.value());
if (fg == nullptr) {
MS_LOG(EXCEPTION) << "Failed to load the compilation cache file: " << realpath.value();
}
FuncGraphManagerPtr mng = fg->manager();
if (mng == nullptr) {
auto res_mng = resource->manager();
MS_EXCEPTION_IF_NULL(res_mng);
res_mng->AddFuncGraph(fg);
fg->set_manager(res_mng);
}
resource->set_func_graph(fg);
}
void CacheFuncGraph(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
auto realpath = Common::GetRealPath(kCompileCacheFilePath);
if (!realpath.has_value()) {
MS_LOG(EXCEPTION) << "Get real path failed. filename=" << kCompileCacheFilePath;
}
ChangeFileMode(realpath.value(), S_IRWXU);
std::ofstream fout(realpath.value());
if (!fout.is_open()) {
MS_LOG(EXCEPTION) << "Open cache file '" << realpath.value() << "' failed!";
}
FuncGraphPtr fg = resource->func_graph();
mind_ir::ModelProto fg_model = GetBinaryProto(fg);
if (!fg_model.SerializeToOstream(&fout)) {
MS_LOG(EXCEPTION) << "Failed to cache the graph to file " << realpath.value();
}
fout.close();
ChangeFileMode(realpath.value(), S_IRUSR);
}
} // namespace
py::tuple GenerateKey(const std::string &name, const std::unordered_map<std::string, py::object> &defaults) {
@@ -556,6 +608,11 @@ std::vector<ActionItem> GetPipeline(const ResourcePtr &resource, const std::stri
// Connect session to debugger
backend_ptr->SetDebugger();
resource->results()[kBackend] = backend_ptr;
// If the 'load_compile_cache' context attribute is set to true and the compilation cache has been
// loaded successfully, run the backend actions only.
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE) && resource->func_graph() != nullptr) {
return BackendPipeline();
}
return VmPipeline();
}
return GePipeline();
@@ -583,6 +640,10 @@ bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, cons
MS_LOG(INFO) << "ExecutorPy compile phase:" << phase_s << "!";
ResourcePtr resource = std::make_shared<Resource>(obj);
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_LOAD_COMPILE_CACHE)) {
GetCachedFuncGraph(resource);
}
auto p_actions = GetPipeline(resource, phase_s, use_vm);
std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, FilterActions(p_actions, phase_s));
@@ -711,6 +772,10 @@ void Pipeline::Run() {
};
if (action.first == "task_emit") {
SetGpuLoopSink(resource_);
} else if (action.first == "validate") {
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_COMPILE_CACHE)) {
CacheFuncGraph(resource_);
}
}
if (!result) {
MS_LOG(EXCEPTION) << "Pipeline running to end, failed in step:" << action.first;

View File

@@ -98,7 +98,9 @@ REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("env_config_path", MsCtxParam::MS_CTX_ENV_CONFIG_PATH)
.value("graph_kernel_flags", MsCtxParam::MS_CTX_GRAPH_KERNEL_FLAGS)
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR);
.value("grad_for_scalar", MsCtxParam::MS_CTX_GRAD_FOR_SCALAR)
.value("save_compile_cache", MsCtxParam::MS_CTX_SAVE_COMPILE_CACHE)
.value("load_compile_cache", MsCtxParam::MS_CTX_LOAD_COMPILE_CACHE);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified parameter.")

View File

@@ -71,6 +71,7 @@ class IrExporter {
explicit IrExporter(IrExportBuilderPtr builder) : builder_(builder) {}
virtual ~IrExporter() = default;
std::string GetDumpString(const FuncGraphPtr &func_graph);
mind_ir::ModelProto GetDumpProto(const FuncGraphPtr &func_graph);
private:
IrExportBuilderPtr builder_;
@@ -83,6 +84,7 @@ class IrExportBuilder {
std::string GetProtoString(const FuncGraphPtr &func_graph);
void BuildModelInfo();
void BuildModel(const FuncGraphPtr &func_graph);
mind_ir::ModelProto Model() { return model_; }
private:
void BuildFuncGraph(const FuncGraphPtr &func_graph, mind_ir::GraphProto *const graph_proto);
@@ -146,6 +148,20 @@ std::string IrExporter::GetDumpString(const FuncGraphPtr &func_graph) {
return builder_->GetProtoString(func_graph);
}
mind_ir::ModelProto IrExporter::GetDumpProto(const FuncGraphPtr &func_graph) {
if ((builder_ == nullptr) || (func_graph == nullptr)) {
MS_LOG(EXCEPTION) << "Input params is null.";
}
// Export model info
builder_->BuildModelInfo();
// Export model and return the model proto
builder_->BuildModel(func_graph);
return builder_->Model();
}
std::string IrExportBuilder::GetProtoString(const FuncGraphPtr &func_graph) {
MS_LOG(DEBUG) << "BuildModel complete!";
return model_.SerializeAsString();
@@ -729,4 +745,9 @@ std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) {
}
return exporter->GetDumpString(func_graph);
}
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph) {
auto exporter = std::make_shared<IrExporter>(std::make_shared<IrExportBuilder>());
return exporter->GetDumpProto(func_graph);
}
} // namespace mindspore

View File

@@ -513,7 +513,8 @@ def _check_target_specific_cfgs(device, arg_key):
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str,
-enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str)
+enable_sparse=bool, max_call_depth=int, env_config_path=str, graph_kernel_flags=str,
+save_compile_cache=bool, load_compile_cache=bool)
def set_context(**kwargs):
"""
Set context for running environment.
@@ -549,6 +550,8 @@ def set_context(**kwargs):
save_graphs_path auto_tune_mode
env_config_path graph_kernel_flags
grad_for_scalar
save_compile_cache
load_compile_cache
=========================== =========================== =================
Args:
@@ -633,6 +636,11 @@
- rl_tune: Reinforcement Learning tune.
- ga_tune: Genetic Algorithm tune.
grad_for_scalar (bool): Whether to get gradient for scalar. Default: False.
save_compile_cache (bool): Experimental. Whether to cache the graph compiled by the frontend. Default: False.
load_compile_cache (bool): Experimental. Whether to use the compilation cache of the graph compiled by the
frontend. When it is true, graph compilation skips the frontend compilation process, so you must make
sure the network has not changed since the last execution. Automatic detection of such changes is not
supported yet. Default: False.
Raises:
ValueError: If input key is not an attribute in context.
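For reference, a minimal usage sketch of the two new options documented above (the Net cell, shapes, and device_target are placeholders for illustration; only save_compile_cache and load_compile_cache come from this patch). The first run writes compile_cache.mindir to the working directory; running the same script again loads it and skips the frontend compilation:

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

# Enable both saving and loading of the frontend compilation cache (experimental).
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    save_compile_cache=True, load_compile_cache=True)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = Net()
x = Tensor(np.ones([1, 4]).astype(np.float32))
# The first execution compiles the graph and saves compile_cache.mindir;
# later executions of the same script load the cache instead of recompiling the frontend.
out = net(x)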

View File

@@ -139,7 +139,11 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
std::fstream input_graph(abs_path_buff, std::ios::in | std::ios::binary);
mind_ir::ModelProto origin_model;
-if (!input_graph || !origin_model.ParseFromIstream(&input_graph)) {
+if (!input_graph) {
+MS_LOG(ERROR) << "Failed to open file: " << file_name;
+return nullptr;
+}
+if (!origin_model.ParseFromIstream(&input_graph)) {
MS_LOG(ERROR) << "Load MindIR file failed, please check the correctness of the file.";
return nullptr;
}

View File

@@ -84,6 +84,8 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);
set_param<bool>(MS_CTX_SAVE_COMPILE_CACHE, false);
set_param<bool>(MS_CTX_LOAD_COMPILE_CACHE, false);
backend_policy_ = policy_map_[policy];
}

View File

@@ -86,6 +86,8 @@ enum MsCtxParam : unsigned {
MS_CTX_ENABLE_PARALLEL_SPLIT,
MS_CTX_ENABLE_INFER_OPT,
MS_CTX_GRAD_FOR_SCALAR,
MS_CTX_SAVE_COMPILE_CACHE,
MS_CTX_LOAD_COMPILE_CACHE,
MS_CTX_TYPE_BOOL_END,
// parameter of type int

View File

@@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_compile_cache=True, load_compile_cache=True)
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
def train(net, data, label):
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
print("+++++++++Loss+++++++++++++")
print(res)
print("+++++++++++++++++++++++++++")
diff = res.asnumpy() - 2.302585
assert np.all(diff < 1.e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_lenet():
path = "compile_cache.mindir"
if os.path.exists(path):
os.remove(path)
assert not os.path.exists(path)
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
train(net, data, label)
assert os.path.exists(path)
data1 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label1 = Tensor(np.ones([32]).astype(np.int32))
net1 = LeNet()
train(net1, data1, label1)

View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "debug/dump_proto.h"
#include "proto/mind_ir.pb.h"
namespace mindspore {
@@ -24,4 +25,9 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { return "";
std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { return ""; }
std::string GetBinaryProtoString(const FuncGraphPtr &func_graph) { return ""; }
mind_ir::ModelProto GetBinaryProto(const FuncGraphPtr &func_graph) {
mind_ir::ModelProto empty_model;
return empty_model;
}
} // namespace mindspore