forked from mindspore-Ecosystem/mindspore

parallel_support_compile_cache

parent 6c9fd1bd63
commit 36a1c3538f
@@ -231,7 +231,7 @@ class MicroStepAllGatherPass : public AnfVisitor {
     auto attrs = prim->attrs();
     std::string group = attrs[parallel::GROUP]->ToString();
     if (group.empty()) {
-      return nullptr;
+      return inputs[1];
     }
     auto fusion = attrs[parallel::FUSION];
     bool contain_recompute = prim->HasAttr(parallel::RECOMPUTE);
@@ -34,6 +34,7 @@
 #include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/graph_util/graph_info.h"
 #include "frontend/parallel/graph_util/node_info.h"
+#include "frontend/parallel/graph_util/get_parallel_info.h"
 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
 #include "frontend/parallel/node_check.h"
 #include "ir/param_info.h"
@@ -416,6 +417,58 @@ void SliceParameterObj(const ParameterPtr &parameter, const TensorLayoutPtr &ten
   }
 }
 
+static void SliceCacheParameterObj(const ParameterPtr &parameter, const py::dict &layout_dict) {
+  auto param_info = parameter->param_info();
+  if (param_info == nullptr) {
+    MS_LOG(WARNING) << "parameter: " << parameter->DebugString() << " doesn't have param_info.";
+    return;
+  }
+  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
+  MS_EXCEPTION_IF_NULL(graph_executor);
+  auto phase = graph_executor->phase();
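+  // Look up the python Parameter object recorded in param_info.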
+  auto py_obj = GetPyParameterObj(param_info, OBJ);
+  if (py::isinstance<py::none>(py_obj)) {
+    MS_LOG(WARNING) << "Parameter: " << parameter->DebugString() << " can't find python obj.";
+    return;
+  }
+  auto name = parameter->name();
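+  // A parameter absent from layout_dict is not sliced; only its optimizer state is initialized.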
+  if (!layout_dict.contains(name)) {
+    (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, INIT_OPTIMIZER_STATE_FN, py_obj, py::str(phase));
+    return;
+  }
+  auto layout = layout_dict[py::str(name)];
+  // Call Python _slice_parameter Fn to slice python parameter obj
+  (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, py_obj, py::str(phase), layout);
+
+  // handle cloned parameter, like accu_grad and optimizer param
+  auto cloned_py_obj = GetPyParameterObj(param_info, CLONED_OBJ);
+  if (!py::isinstance<py::none>(cloned_py_obj)) {
+    if (!py::isinstance<py::list>(cloned_py_obj)) {
+      MS_LOG(EXCEPTION) << "parameter: " << parameter->DebugString() << " doesn't have correct cloned obj";
+    }
+    auto obj_list = py::cast<py::list>(cloned_py_obj);
+    for (size_t i = 0; i < obj_list.size(); ++i) {
+      py::object each_cloned_obj = obj_list[i];
+      (void)python_adapter::CallPyFn(SLICE_PARAMETER_FN_PATH, SLICE_PARAMETER_FN_NAME, each_cloned_obj, py::str(phase),
+                                     layout);
+    }
+  }
+}
+
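+// Slice every parameter with a default value according to the layouts cached in the compile-cache resource.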
+void InitCompileCacheParams(const pipeline::ResourcePtr &resource) {
+  auto layout_dict = GetParameterLayoutFromResource(resource);
+  auto graph = resource->func_graph();
+  auto params = graph->parameters();
+  for (auto &param : params) {
+    auto param_ptr = param->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (!param_ptr->has_default()) {
+      continue;
+    }
+    SliceCacheParameterObj(param_ptr, layout_dict);
+  }
+}
+
 void InitPynativeNoShardParams(const FuncGraphPtr &root) {
   auto parameters = root->parameters();
   for (auto &parameter : parameters) {
@@ -25,6 +25,7 @@
 #include "base/base.h"
 #include "frontend/parallel/device_manager.h"
 #include "frontend/parallel/step_parallel_utils.h"
+#include "pipeline/jit/resource.h"
 #include "pybind11/pybind11.h"
 
 namespace py = pybind11;
@@ -54,6 +55,7 @@ void HandleAdaFactorOpt(const FuncGraphPtr &root);
 void AutoParallelPostProcess(const FuncGraphPtr &root);
 // Init the parameters for graph which not specified by shard under PyNative mode.
 void InitPynativeNoShardParams(const FuncGraphPtr &root);
+void InitCompileCacheParams(const pipeline::ResourcePtr &resource);
 std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
 std::pair<AnfNodePtr, bool> FindParameterWithAllgather(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
                                                        const std::string &name);
@@ -761,7 +761,7 @@ void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const st
 #endif
 }
 
-void GraphExecutorPy::ParallelPostProcess(const std::string &phase) {
+void GraphExecutorPy::ParallelPostProcess(const std::string &phase, bool use_compile_cache) {
   // Slice Python parameter obj
   auto layout_graph = phase + kStepParallelGraph;
   // only Parallel graph has tensor_layout
@@ -770,6 +770,11 @@ void GraphExecutorPy::ParallelPostProcess(const std::string &phase) {
   if (phase.find("after_shard") != std::string::npos) {
     after_shard = true;
   }
+  // Use compile cache
+  if (use_compile_cache) {
+    parallel::InitCompileCacheParams(info_[phase]->resource);
+    return;
+  }
   // Initialize parameters for graph which auto-parallel not care.
   if (root == nullptr && !after_shard) {
     auto graph = info_[phase]->resource->func_graph();
@@ -823,6 +828,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
   ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
   ResourcePtr resource = std::make_shared<Resource>(source_obj);
   InitCompileCacheInfo(resource, phase);
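+  // The compile cache is used only when it is enabled and a cached func_graph was loaded.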
+  bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph();
   ConfigManager::GetInstance().ResetQueue(queue_name_);
 
   auto actions = GetPipeline(resource, phase, use_vm);
@@ -886,7 +892,7 @@ bool GraphExecutorPy::CompileInner(const py::object &source_obj, const py::tuple
   // Save the compiled graph to MsPipeLine.
   SaveCompiledGraph(phase);
   if (is_auto_parallel) {
-    ParallelPostProcess(phase);
+    ParallelPostProcess(phase, use_compile_cache);
   }
 #ifdef ENABLE_DUMP_IR
   mindspore::RDR::Snapshot();
@@ -137,7 +137,7 @@ class GraphExecutorPy : public std::enable_shared_from_this<GraphExecutorPy> {
 
  private:
   GraphExecutorPy() = default;
-  void ParallelPostProcess(const string &phase);
+  void ParallelPostProcess(const string &phase, bool use_compile_cache);
   void GetGeBackendPolicy() const;
   // filter some pipeline actions according to phase, e.g. when exporting onnx, it is no need to execute actions after
   // 'validate' stage