parallel_support_compile_cache

lichen 2023-01-04 11:47:54 +08:00
parent 6c9fd1bd63
commit 36a1c3538f
5 changed files with 65 additions and 4 deletions

View File

@ -231,7 +231,7 @@ class MicroStepAllGatherPass : public AnfVisitor {
auto attrs = prim->attrs();
std::string group = attrs[parallel::GROUP]->ToString();
if (group.empty()) {
-      return nullptr;
+      return inputs[1];
}
auto fusion = attrs[parallel::FUSION];
bool contain_recompute = prim->HasAttr(parallel::RECOMPUTE);

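Aside, not part of the diff: in these AnfVisitor-based rewrite passes, returning nullptr conventionally means "leave the matched node alone", while returning another AnfNodePtr substitutes it in the graph. So for an AllGather whose GROUP attribute is empty, returning inputs[1] (the first real input; inputs[0] holds the primitive) now drops the pointless AllGather instead of keeping it. A minimal sketch of that contract, with a hypothetical helper name:

// Sketch only (hypothetical helper, not from the commit). Illustrates the
// rewrite contract assumed above: nullptr = keep the matched node,
// any other AnfNodePtr = replace the matched CNode with it.
AnfNodePtr RewriteEmptyGroupAllGather(const CNodePtr &cnode) {
  const auto &inputs = cnode->inputs();  // inputs[0] is the Primitive value node
  if (inputs.size() <= 1) {
    return nullptr;  // malformed node: keep it untouched
  }
  return inputs[1];  // forward the data input, eliding the AllGather
}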
View File

@ -34,6 +34,7 @@
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/graph_util/graph_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/pipeline_split_utils.h"
#include "frontend/parallel/node_check.h"
#include "ir/param_info.h"
@ -416,6 +417,58 @@ void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &ten
}
}
+static void SliceCacheParameterObj(const ParameterPtr &parameter, const py::dict &layout_dict) {
+  auto param_info = parameter->param_info();
+  if (param_info == nullptr) {
+    MS_LOG(WARNING) << "parameter: " << parameter->DebugString() << " doesn't have param_info.";
+    return;
+  }
+  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
+  MS_EXCEPTION_IF_NULL(graph_executor);
+  auto phase = graph_executor->phase();
+  auto py_obj = GetPyParameterObj(param_info, OBJ);
+  if (py::isinstance<py::none>(py_obj)) {
+    MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
+    return;
+  }
+  auto name = parameter->name();
+  if (!layout_dict.contains(name)) {
+    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, INIT_OPTIMIZER_STATE_FN, py_obj, py::str(phase));
+    return;
+  }
+  auto layout = layout_dict[py::str(name)];
+  // Call Python _slice_parameter Fn to slice python parameter obj
+  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);
+  // handle cloned parameter, like accu_grad and optimizer param
+  auto cloned_py_obj = GetPyParameterObj(param_info, CLONED_OBJ);
+  if (!py::isinstance<py::none>(cloned_py_obj)) {
+    if (!py::isinstance<py::list>(cloned_py_obj)) {
+      MS_LOG(EXCEPTION) << "parameter: " << parameter->DebugString() << " doesn't have correct cloned obj";
+    }
+    auto obj_list = py::cast<py::list>(cloned_py_obj);
+    for (size_t i = 0; i < obj_list.size(); ++i) {
+      py::object each_cloned_obj = obj_list[i];
+      (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, each_cloned_obj, py::str(phase),
+                                     layout);
+    }
+  }
+}
+
+void InitCompileCacheParams(const pipeline::ResourcePtr &resource) {
+  auto layout_dict = GetParameterLayoutFromResource(resource);
+  auto graph = resource->func_graph();
+  auto params = graph->parameters();
+  for (auto &param : params) {
+    auto param_ptr = param->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (!param_ptr->has_default()) {
+      continue;
+    }
+    SliceCacheParameterObj(param_ptr, layout_dict);
+  }
+}
+
void InitPynativeNoShardParams(const FuncGraphPtr &root) {
auto parameters = root->parameters();
for (auto &parameter : parameters) {

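Aside, not part of the diff: the two new functions assume the cached layouts survive in the deserialized resource as a name-to-layout py::dict (via GetParameterLayoutFromResource), and that only parameters carrying a default value (i.e. weights) need slicing. The per-parameter decision then has three outcomes, roughly:

// Sketch of the assumed decision flow, not the verbatim implementation.
py::dict layout_dict = GetParameterLayoutFromResource(resource);
for (auto &param : resource->func_graph()->parameters()) {
  auto param_ptr = param->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(param_ptr);
  if (!param_ptr->has_default()) {
    continue;  // plain graph inputs carry no weight data, nothing to slice
  }
  if (!layout_dict.contains(param_ptr->name())) {
    // 1) no cached layout: only initialize optimizer state for this obj
  } else {
    // 2) cached layout: slice the Python parameter obj with it, and
    // 3) slice every cloned obj (accu_grad, optimizer params) the same way
  }
}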
View File

@ -25,6 +25,7 @@
#include "base/base.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "pipeline/jit/resource.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
@ -54,6 +55,7 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root);
void AutoParallelPostProcess(const FuncGraphPtr &root);
// Init the parameters for graph which not specified by shard under PyNative mode.
void InitPynativeNoShardParams(const FuncGraphPtr &root);
+void InitCompileCacheParams(const pipeline::ResourcePtr &resource);
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
const std::string &name);

View File

@ -761,7 +761,7 @@ void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const st
#endif
}
-void GraphExecutorPy::ParallelPostProcess(const std::string &phase) {
+void GraphExecutorPy::ParallelPostProcess(const std::string &phase, bool use_compile_cache) {
// Slice Python parameter obj
auto layout_graph = phase + kStepParallelGraph;
// only Parallel graph has tensor_layout
@ -770,6 +770,11 @@ void GraphExecutorPy::ParallelPostProcess(const std::string &phase) {
if (phase.find("after_shard") != std::string::npos) {
after_shard = true;
}
+  // Use compile cache
+  if (use_compile_cache) {
+    parallel::InitCompileCacheParams(info_[phase]->resource);
+    return;
+  }
// Initialize parameters for graph which auto-parallel not care.
if (root == nullptr && !after_shard) {
auto graph = info_[phase]->resource->func_graph();
@ -823,6 +828,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
ResourcePtr resource = std::make_shared<Resource>(source_obj);
InitCompileCacheInfo(resource, phase);
+  bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph();
ConfigManager::GetInstance().ResetQueue(queue_name_);
auto actions = GetPipeline(resource, phase, use_vm);
@ -886,7 +892,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
// Save the compiled graph to MsPipeLine.
SaveCompiledGraph(phase);
if (is_auto_parallel) {
-    ParallelPostProcess(phase);
+    ParallelPostProcess(phase, use_compile_cache);
}
#ifdef ENABLE_DUMP_IR
mindspore::RDR::Snapshot();

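Aside, not part of the diff: the gate is two-sided. EnableCompileCache() only says the feature is switched on; a non-null resource->func_graph() this early in CompileInner appears to mean InitCompileCacheInfo actually deserialized a cached graph, i.e. a genuine cache hit. On a miss, ParallelPostProcess and parameter slicing run exactly as before. Condensed:

// Sketch of the gating as wired up in CompileInner above.
bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph();
// ... pipeline actions run, graph is compiled and saved ...
if (is_auto_parallel) {
  // Cache hit: restore slices from cached layouts; miss: tensor_layout path.
  ParallelPostProcess(phase, use_compile_cache);
}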
View File

@ -137,7 +137,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
private:
GraphExecutorPy() = default;
-  void ParallelPostProcess(const string &phase);
+  void ParallelPostProcess(const string &phase, bool use_compile_cache);
void GetGeBackendPolicy() const;
// filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after
// 'validate' stage