diff --git a/mindspore/ccsrc/optimizer/optimizer.h b/mindspore/ccsrc/optimizer/optimizer.h
index f67466efba0..c4455484c43 100644
--- a/mindspore/ccsrc/optimizer/optimizer.h
+++ b/mindspore/ccsrc/optimizer/optimizer.h
@@ -17,7 +17,9 @@
 #ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
 #define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_H_
 
+#include <algorithm>
 #include <functional>
+#include <iterator>
 #include <memory>
 #include <string>
 #include <vector>
@@ -129,29 +131,38 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
     return optimizer;
   }
 
-  FuncGraphPtr step(FuncGraphPtr func_graph, const abstract::AbstractBasePtrList &args_spec, bool use_profile = true) {
+  FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) {
     // Optimizer step counter;
     int counter = 1;
     bool changes = true;
 
    while (changes) {
      changes = false;
-      auto run_runc = [&counter, &func_graph, &args_spec, &changes, use_profile, this]() {
+      auto run_runc = [&counter, &func_graph, &changes, use_profile, this]() {
        for (size_t i = 0; i < passes_.size(); ++i) {
          const OptPass &opt = passes_[i];
-          auto opt_func = [&func_graph, &args_spec, &changes, &opt, this]() {
+          auto opt_func = [&func_graph, &changes, &opt, this]() {
            if (opt.is_renormalize()) {
              auto resource_ptr = std::dynamic_pointer_cast<pipeline::Resource>(resource_);
              if (resource_ptr != nullptr) {
+                // StepParallel may replace the AbstractValue of the parameters of func_graph,
+                // so regenerate the args_spec from the current parameters.
+                abstract::AbstractBasePtrList maybe_new_args_spec;
                if (is_watch_renormalize_) {
                  if (untyped_nodes_.size() > 0) {
-                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                    std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                   std::back_inserter(maybe_new_args_spec),
+                                   [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                    func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                    clear_untyped_nodes();
                  } else {
                    MS_LOG(INFO) << "Optimizer::step: Skipping Renormalize because untyped_nodes_ is empty.";
                  }
                } else {
-                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, args_spec);
+                  std::transform(func_graph->parameters().begin(), func_graph->parameters().end(),
+                                 std::back_inserter(maybe_new_args_spec),
+                                 [](AnfNodePtr param) -> AbstractBasePtr { return param->abstract(); });
+                  func_graph = pipeline::Renormalize(resource_ptr, func_graph, maybe_new_args_spec);
                }
              }
            } else if (opt(func_graph, shared_from_this())) {
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index c4614dc6478..87588a2820b 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -1230,7 +1230,12 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair
                << MakeValue(slice_shape)->ToString();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  // Don't modify it in-place, as the pointer of this AbstractValue may be used as a cache key in StaticAnalysis.
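+  // Clone() yields a distinct object, so cache entries keyed on the old pointer remain untouched.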
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  parameter->set_abstract(cloned_abstract);
   TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
   ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(parameter_ptr);
@@ -1330,7 +1334,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
       cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
       MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
       MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
-      cloned_parameter_node->abstract()->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
+      MS_EXCEPTION_IF_NULL(cloned_abstract);
+      cloned_abstract->set_shape(cloned_from_node->abstract()->GetShapeTrack());
+      cloned_parameter_node->set_abstract(cloned_abstract);
       MS_LOG(INFO) << "The parameter: " << cloned_parameter->name()
                    << " is cloned, the be cloned parameter is: " << cloned_from_parameter->name()
                    << ", clone index is: " << cloned_index;
@@ -1742,7 +1749,10 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
   auto slice_shape = loss_grad_layout.slice_shape().array();
   std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
   MS_EXCEPTION_IF_NULL(parallel_shape);
-  abstract->set_shape(parallel_shape);
+  auto cloned_abstract = abstract->Clone();
+  MS_EXCEPTION_IF_NULL(cloned_abstract);
+  cloned_abstract->set_shape(parallel_shape);
+  sens_tensor_node->set_abstract(cloned_abstract);
   auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
   sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
   return;
diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc
index 284512c9430..18f186dbb15 100644
--- a/mindspore/ccsrc/pipeline/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/parse/resolve.cc
@@ -276,9 +276,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa
 
   (void)parse::python_adapter::set_python_scoped();
 
-  abstract::AbstractBasePtrList args_spec;
   MS_EXCEPTION_IF_NULL(opt_resolve);
-  (void)opt_resolve->step(func_graph, args_spec, use_profile);
+  (void)opt_resolve->step(func_graph, use_profile);
 
   return true;
 }
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index 6cdf6414433..6ce6c4603d8 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -205,14 +205,16 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
     return false;
   }
 
-  abstract::AbstractBasePtrList args = res->args_spec();
   FuncGraphPtr func_graph = res->func_graph();
   MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
                << func_graph->get_return()->DebugString(true);
   InitOpt(res);
   if (g_pass_opts.find(name) != g_pass_opts.end()) {
-    res->set_func_graph(g_pass_opts[name]->step(func_graph, args));
+    res->set_func_graph(g_pass_opts[name]->step(func_graph));
   }
+  // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated
+  // to res->args_spec_ yet. So if any later pass or action wants to use that variable, it should be set here.
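+  // (Optimizer::step now rebuilds the args_spec from func_graph->parameters() before each Renormalize.)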
 
   return true;
 }
@@ -255,10 +256,9 @@ bool ValidatePass(const ResourcePtr &res) {
 bool InferenceOptPreparePass(const ResourcePtr &res) {
   FuncGraphPtr func_graph = res->func_graph();
   MS_EXCEPTION_IF_NULL(func_graph);
-  abstract::AbstractBasePtrList args_spec = res->args_spec();
   auto prepare_map = GetInferenceOptPreparePhases();
   auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
-  (void)infer_opt_prepare->step(func_graph, args_spec, false);
+  (void)infer_opt_prepare->step(func_graph, false);
   return true;
 }
 
diff --git a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
index 5bad1634d56..402ef980013 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/evaluator.cc
@@ -260,7 +260,6 @@ AbstractBasePtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const Config
                          return conf->GetEvaluatedValue();
                        });
   AbstractBasePtr ret = EvalPrim(engine, args_spec_list);
-  (*cache_)[args_spec_list] = ret;
   return ret;
 }
 
diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
index 46e088ab11a..1115cd9978d 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc
@@ -405,6 +405,10 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
 
 AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
+  const auto &iter = cache_->find(args);
+  if (iter != cache_->end()) {
+    return iter->second;
+  }
   auto py_args = PreparePyInputs(prim_py_, args);
 
   auto pyobj = prim_py_->GetPyObj();
@@ -418,6 +422,7 @@ AbstractBasePtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const A
 
   auto res_spec = PyInferRes2Abstract(prim_py_, output);
   MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
+  (*cache_)[args] = res_spec;
   return res_spec;
 }
 
diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
index 6230df44a56..4afc3509baa 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
+++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc
@@ -271,6 +271,18 @@ void AnalysisEngine::ClearEvaluatorCache() {
     MS_EXCEPTION_IF_NULL(evaluator->cache());
     evaluator->cache()->clear();
   }
+  for (auto &element : prim_constructors_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
+  for (auto &element : prim_py_evaluators_) {
+    EvaluatorPtr evaluator = element.second;
+    MS_EXCEPTION_IF_NULL(evaluator);
+    MS_EXCEPTION_IF_NULL(evaluator->cache());
+    evaluator->cache()->clear();
+  }
 }
 
 void AnalysisEngine::Clear() {
@@ -296,7 +308,17 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
-      return std::make_shared<PythonPrimEvaluator>(prim_py);
+      if (engine == nullptr) {
+        return std::make_shared<PythonPrimEvaluator>(prim_py);
+      }
+
+      const auto &iter = engine->prim_py_evaluators_.find(prim_py);
+      if (iter != engine->prim_py_evaluators_.end()) {
+        return iter->second;
+      }
+      evaluator = std::make_shared<PythonPrimEvaluator>(prim_py);
+      engine->prim_py_evaluators_[prim_py] = evaluator;
+      return evaluator;
     }
     MS_LOG(EXCEPTION) << "The primitive with python evaluator should be a python primitive.";
   }
diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
index ef4f78e619e..80c63204939 100644
--- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
+++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h
@@ -194,6 +194,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
   const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; }
 
   AnalysisCache cache_;
+  std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_;
 
  private:
   const PrimEvaluatorMap &prim_constructors_;
diff --git a/tests/ut/cpp/optimizer/optimizer_test.cc b/tests/ut/cpp/optimizer/optimizer_test.cc
index d700225894d..ca7c589d47d 100644
--- a/tests/ut/cpp/optimizer/optimizer_test.cc
+++ b/tests/ut/cpp/optimizer/optimizer_test.cc
@@ -57,8 +57,7 @@ TEST_F(TestOptOptimizer, test_step_opt) {
                                                       true);
   EXPECT_TRUE(optimizer.get() != nullptr);
 
-  abstract::AbstractBasePtrList args;
-  auto after = optimizer->step(before, args);
+  auto after = optimizer->step(before);
 
   draw::Draw("optimizer_test_expendJ_before.dot", before);
   draw::Draw("optimizer_test_expendJ_after.dot", after);