forked from mindspore-Ecosystem/mindspore
!264 static_analysis: remove useless cache in TrivialPrimEvaluator and add cache for PythonPrimEvaluator
Merge pull request !264 from xychow/remove-unnecessary-cache-and-add-cache
This commit is contained in:
commit
53d2da5fe4
|
@ -17,7 +17,9 @@
|
|||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -129,29 +131,38 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
|
|||
return optimizer;
|
||||
}
|
||||
|
||||
FuncGraphPtr step(FuncGraphPtr func_graph, const abstract::AbstractBasePtrList &args_spec, bool use_profile = true) {
|
||||
FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
|
||||
// Optimizer step counter;
|
||||
int counter = 1;
|
||||
bool changes = true;
|
||||
|
||||
while (changes) {
|
||||
changes = false;
|
||||
auto run_runc = [&counter, &func_graph, &args_spec, &changes, use_profile, this]() {
|
||||
auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() {
|
||||
for (size_t i = 0; i < passes_.size(); ++i) {
|
||||
const OptPass &opt = passes_[i];
|
||||
auto opt_func = [&func_graph, &args_spec, &changes, &opt, this]() {
|
||||
auto opt_func = [&func_graph, &changes, &opt, this]() {
|
||||
if (opt.is_renormalize()) {
|
||||
auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
|
||||
if (resource_ptr != nullptr) {
|
||||
// StepParallel may replace the AbstractValue of the parameters of func_graph,
|
||||
// So generate the args_spec from parameters.
|
||||
abstract::AbstractBasePtrList maybe_new_args_spec;
|
||||
if (is_watch_renormalize_) {
|
||||
if (untyped_nodes_.size() > 0) {
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
|
||||
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
|
||||
std::back_inserter(maybe_new_args_spec),
|
||||
[](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
|
||||
clear_untyped_nodes();
|
||||
} else {
|
||||
MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
|
||||
}
|
||||
} else {
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
|
||||
std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
|
||||
std::back_inserter(maybe_new_args_spec),
|
||||
[](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
|
||||
func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
|
||||
}
|
||||
}
|
||||
} else if (opt(func_graph, shared_from_this())) {
|
||||
|
|
|
@ -1230,7 +1230,11 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|||
<< MakeValue(slice_shape)->ToString();
|
||||
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
MS_EXCEPTION_IF_NULL(parallel_shape);
|
||||
abstract->set_shape(parallel_shape);
|
||||
// Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
|
||||
auto cloned_abstract = abstract->Clone();
|
||||
MS_EXCEPTION_IF_NULL(cloned_abstract);
|
||||
cloned_abstract->set_shape(parallel_shape);
|
||||
parameter->set_abstract(cloned_abstract);
|
||||
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
|
||||
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(parameter_ptr);
|
||||
|
@ -1330,7 +1334,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|||
cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
|
||||
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
|
||||
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
|
||||
cloned_parameter_node->abstract()->set_shape(cloned_from_node->abstract()->GetShapeTrack());
|
||||
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
|
||||
MS_EXCEPTION_IF_NULL(cloned_abstract);
|
||||
cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
|
||||
cloned_parameter_node->set_abstract(cloned_abstract);
|
||||
MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
|
||||
<< " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
|
||||
<< ", clone index is: " << cloned_index;
|
||||
|
@ -1742,7 +1749,10 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
|
|||
auto slice_shape = loss_grad_layout.slice_shape().array();
|
||||
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
|
||||
MS_EXCEPTION_IF_NULL(parallel_shape);
|
||||
abstract->set_shape(parallel_shape);
|
||||
auto cloned_abstract = abstract->Clone();
|
||||
MS_EXCEPTION_IF_NULL(cloned_abstract);
|
||||
cloned_abstract->set_shape(parallel_shape);
|
||||
sens_tensor_node->set_abstract(cloned_abstract);
|
||||
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
|
||||
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
|
||||
return;
|
||||
|
|
|
@ -276,9 +276,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
|
|||
|
||||
(void)parse::python_adapter::set_python_scoped();
|
||||
|
||||
abstract::AbstractBasePtrList args_spec;
|
||||
MS_EXCEPTION_IF_NULL(opt_resolve);
|
||||
(void)opt_resolve->step(func_graph, args_spec, use_profile);
|
||||
(void)opt_resolve->step(func_graph, use_profile);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -205,14 +205,15 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
|
|||
return false;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtrList args = res->args_spec();
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
|
||||
<< func_graph->get_return()->DebugString(true);
|
||||
InitOpt(res);
|
||||
if (g_pass_opts.find(name) != g_pass_opts.end()) {
|
||||
res->set_func_graph(g_pass_opts[name]->step(func_graph, args));
|
||||
res->set_func_graph(g_pass_opts[name]->step(func_graph));
|
||||
}
|
||||
// Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to
|
||||
// res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here.
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -255,10 +256,9 @@ bool ValidatePass(const ResourcePtr &res) {
|
|||
bool InferenceOptPreparePass(const ResourcePtr &res) {
|
||||
FuncGraphPtr func_graph = res->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
abstract::AbstractBasePtrList args_spec = res->args_spec();
|
||||
auto prepare_map = GetInferenceOptPreparePhases();
|
||||
auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
|
||||
(void)infer_opt_prepare->step(func_graph, args_spec, false);
|
||||
(void)infer_opt_prepare->step(func_graph, false);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -260,7 +260,6 @@ AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const Config
|
|||
return conf->GetEvaluatedValue();
|
||||
});
|
||||
AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
|
||||
(*cache_)[args_spec_list] = ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
|
|
@ -405,6 +405,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
|
|||
AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
|
||||
MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
|
||||
|
||||
const auto &iter = cache_->find(args);
|
||||
if (iter != cache_->end()) {
|
||||
return iter->second;
|
||||
}
|
||||
auto py_args = PreparePyInputs(prim_py_, args);
|
||||
|
||||
auto pyobj = prim_py_->GetPyObj();
|
||||
|
@ -418,6 +422,7 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
|
|||
auto res_spec = PyInferRes2Abstract(prim_py_, output);
|
||||
|
||||
MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
|
||||
(*cache_)[args] = res_spec;
|
||||
return res_spec;
|
||||
}
|
||||
|
||||
|
|
|
@ -271,6 +271,18 @@ void AnalysisEngine::ClearEvaluatorCache() {
|
|||
MS_EXCEPTION_IF_NULL(evaluator->cache());
|
||||
evaluator->cache()->clear();
|
||||
}
|
||||
for (auto &element : prim_constructors_) {
|
||||
EvaluatorPtr evaluator = element.second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->cache());
|
||||
evaluator->cache()->clear();
|
||||
}
|
||||
for (auto &element : prim_py_evaluators_) {
|
||||
EvaluatorPtr evaluator = element.second;
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
MS_EXCEPTION_IF_NULL(evaluator->cache());
|
||||
evaluator->cache()->clear();
|
||||
}
|
||||
}
|
||||
|
||||
void AnalysisEngine::Clear() {
|
||||
|
@ -296,8 +308,18 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
|
|||
if (prim->HasPyEvaluator()) {
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
if (prim_py != nullptr) {
|
||||
if (engine == nullptr) {
|
||||
return std::make_shared<PythonPrimEvaluator>(prim_py);
|
||||
}
|
||||
|
||||
const auto &iter = engine->prim_py_evaluators_.find(prim_py);
|
||||
if (iter != engine->prim_py_evaluators_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
|
||||
engine->prim_py_evaluators_[prim_py] = evaluator;
|
||||
return evaluator;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
|
||||
}
|
||||
|
||||
|
|
|
@ -194,6 +194,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
|
||||
|
||||
AnalysisCache cache_;
|
||||
std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
|
||||
|
||||
private:
|
||||
const PrimEvaluatorMap &prim_constructors_;
|
||||
|
|
|
@ -57,8 +57,7 @@ TEST_F(TestOptOptimizer, test_step_opt) {
|
|||
true);
|
||||
EXPECT_TRUE(optimizer.get() != nullptr);
|
||||
|
||||
abstract::AbstractBasePtrList args;
|
||||
auto after = optimizer->step(before, args);
|
||||
auto after = optimizer->step(before);
|
||||
|
||||
draw::Draw("optimizer_test_expendJ_before.dot", before);
|
||||
draw::Draw("optimizer_test_expendJ_after.dot", after);
|
||||
|
|
Loading…
Reference in New Issue