diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc index 0c55b9480c8..347641829dc 100644 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ b/mindspore/ccsrc/operator/composite/composite.cc @@ -1199,51 +1199,6 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec return ret_graph; } -FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { - // slice a tensor - // args: tensor, slice or slice tuple - const std::string op_name = std::string("UnpackCall"); - size_t arg_length = args_spec_list.size(); - if (arg_length < 2) { - MS_LOG(EXCEPTION) << "" << op_name << " requires at least two args, but got " << arg_length << "."; - } - - (void)abstract::CheckArg(op_name, args_spec_list, 0); - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); - - AnfNodePtr fnNode = ret_graph->add_parameter(); - std::vector elems; - elems.push_back(fnNode); - for (size_t index = 1; index < arg_length; index++) { - MS_EXCEPTION_IF_NULL(args_spec_list[index]); - if (args_spec_list[index]->isa()) { - AbstractTuplePtr arg_tuple = dyn_cast(args_spec_list[index]); - AnfNodePtr para_tuple = ret_graph->add_parameter(); - for (size_t i = 0; i < arg_tuple->size(); ++i) { - elems.push_back( - ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); - } - } else if (args_spec_list[index]->isa()) { - AbstractDictionaryPtr arg_dict = dyn_cast(args_spec_list[index]); - AnfNodePtr para_dict = ret_graph->add_parameter(); - auto dict_elems = arg_dict->elements(); - (void)std::transform( - dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), - [ret_graph, para_dict](const AbstractAttribute& item) { - return ret_graph->NewCNode( - {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), - ret_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)})}); - }); - } else { - MS_LOG(EXCEPTION) << "" << op_name << " require args should be tuple or dict, but got " - << args_spec_list[index]->ToString(); - } - } - ret_graph->set_output(ret_graph->NewCNode(elems)); - return ret_graph; -} - REGISTER_PYBIND_DEFINE( TupleAdd_, ([](const py::module* m) { (void)py::class_>(*m, "TupleAdd_").def(py::init()); @@ -1258,10 +1213,5 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { (void)py::class_>(*m, "TensorSlice_") .def(py::init()); })); - -REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { - (void)py::class_>(*m, "UnpackCall_") - .def(py::init()); - })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h index 176efb0425b..dc8627ba615 100644 --- a/mindspore/ccsrc/operator/composite/composite.h +++ b/mindspore/ccsrc/operator/composite/composite.h @@ -29,6 +29,7 @@ #include "operator/composite/zip_operation.h" #include "operator/composite/list_append_operation.h" #include "operator/composite/do_signature.h" +#include "operator/composite/unpack_call.h" #include "pipeline/static_analysis/static_analysis.h" #include "utils/misc.h" #include "utils/any.h" @@ -154,7 +155,7 @@ class GradOperation : public MetaFuncGraph { FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector& ptrParams, bool applyJ = false); FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - + bool sens_param() const { return sens_param_; } bool get_all_; bool get_by_list_; bool sens_param_; @@ -208,17 +209,6 @@ class TensorSlice : public MetaFuncGraph { }; using TensorSlicePtr = std::shared_ptr; -// Expand the tuple and dict parameters generated when parsing the function call, -// and generate positional parameters and key-value pairs for function. -class UnpackCall : public MetaFuncGraph { - public: - explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} - ~UnpackCall() override = default; - MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } -}; -using UnpackCallPtr = std::shared_ptr; } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/unpack_call.cc b/mindspore/ccsrc/operator/composite/unpack_call.cc new file mode 100644 index 00000000000..64d6b3433b1 --- /dev/null +++ b/mindspore/ccsrc/operator/composite/unpack_call.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "operator/composite/unpack_call.h" +#include +#include + +#include "./common.h" +#include "pipeline/static_analysis/abstract_value.h" +#include "pipeline/static_analysis/dshape.h" +#include "pipeline/static_analysis/param_validator.h" +#include "operator/cc_implementations.h" +#include "ir/anf.h" +#include "optimizer/opt.h" +#include "utils/symbolic.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { +using mindspore::abstract::AbstractAttribute; +using mindspore::abstract::AbstractBase; +using mindspore::abstract::AbstractDictionary; +using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractFunction; +using mindspore::abstract::AbstractKeywordArg; +using mindspore::abstract::AbstractTuple; +using mindspore::abstract::AbstractTuplePtr; + +FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { + // slice a tensor + // args: tensor, slice or slice tuple + const std::string op_name = std::string("UnpackCall"); + size_t arg_length = args_spec_list.size(); + if (arg_length < 2) { + MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << "."; + } + + (void)abstract::CheckArg(op_name, args_spec_list, 0); + auto ret_graph = std::make_shared(); + ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); + + AnfNodePtr fnNode = ret_graph->add_parameter(); + std::vector elems; + elems.push_back(fnNode); + for (size_t index = 1; index < arg_length; index++) { + MS_EXCEPTION_IF_NULL(args_spec_list[index]); + if (args_spec_list[index]->isa()) { + auto arg_tuple = args_spec_list[index]->cast(); + AnfNodePtr para_tuple = ret_graph->add_parameter(); + for (size_t i = 0; i < arg_tuple->size(); ++i) { + elems.push_back( + ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))})); + } + } else if (args_spec_list[index]->isa()) { + AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast(); + AnfNodePtr para_dict = ret_graph->add_parameter(); + auto dict_elems = arg_dict->elements(); + (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), + [ret_graph, para_dict](const AbstractAttribute& item) { + auto dict_get_item = ret_graph->NewCNode( + {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); + return ret_graph->NewCNode( + {NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item}); + }); + } else { + MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got " + << args_spec_list[index]->ToString(); + } + } + ret_graph->set_output(ret_graph->NewCNode(elems)); + return ret_graph; +} + +REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { + (void)py::class_>(*m, "UnpackCall_") + .def(py::init()); + })); + +} // namespace prim +} // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h new file mode 100644 index 00000000000..7ec5f9ad33d --- /dev/null +++ b/mindspore/ccsrc/operator/composite/unpack_call.h @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ +#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "pipeline/static_analysis/static_analysis.h" +#include "utils/misc.h" +#include "utils/any.h" +#include "ir/dtype.h" +#include "ir/meta_func_graph.h" +#include "common/utils.h" + +namespace mindspore { +// namespace to support composite operators definition +namespace prim { + +// Expand the tuple and dict parameters generated when parsing the function call, +// and generate positional parameters and key-value pairs for function. +class UnpackCall : public MetaFuncGraph { + public: + explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} + ~UnpackCall() override = default; + MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; + friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } +}; +using UnpackCallPtr = std::shared_ptr; + +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_ diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index f3f5dad5f1d..727d66dfb36 100644 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -246,6 +246,21 @@ class DoSignaturePrimitive : public Primitive { ValuePtr function_; }; using DoSignaturePrimitivePtr = std::shared_ptr; + +class UnpackGraphPrimitive : public Primitive { + public: + explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) + : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} + ~UnpackGraphPrimitive() override = default; + MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) + bool with_sens_in_args() const { return with_sens_in_args_; } + bool need_unpack_args() const { return need_unpack_args_; } + + private: + bool with_sens_in_args_; + bool need_unpack_args_; +}; +using UnpackGraphPrimitivePtr = std::shared_ptr; } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc index ba78696b38e..cdc960792ff 100644 --- a/mindspore/ccsrc/optimizer/irpass.cc +++ b/mindspore/ccsrc/optimizer/irpass.cc @@ -39,6 +39,7 @@ #include "optimizer/irpass/specialize_transform.h" #include "optimizer/irpass/incorporate_getitem.h" #include "optimizer/irpass/incorporate_call.h" +#include "optimizer/irpass/grad_var_prepare.h" namespace mindspore { namespace opt { @@ -123,6 +124,11 @@ ResolveIRPassLib::ResolveIRPassLib() { resolver_resolve_ = MakeSubstitution(ResolverResolve(), "resolver_resolve", prim::kPrimResolve); resolver_getattr_ = MakeSubstitution(ResolverGetattr(), "resolver_getattr", prim::kPrimGetAttr); } + +InferenceOptPrepareLib::InferenceOptPrepareLib() { + grad_var_prepare_ = MakeSubstitution(GradVarPrepare(), "grad_var_prepare", IsCNode); +} + } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h index c2af344e326..bdaf42b3ed1 100644 --- a/mindspore/ccsrc/optimizer/irpass.h +++ b/mindspore/ccsrc/optimizer/irpass.h @@ -102,6 +102,13 @@ class ResolveIRPassLib { SubstitutionPtr resolver_getattr_; }; +class InferenceOptPrepareLib { + public: + InferenceOptPrepareLib(); + ~InferenceOptPrepareLib() = default; + SubstitutionPtr grad_var_prepare_; +}; + // predicate functions inline bool IsNode(const AnfNodePtr &) { return true; } @@ -151,6 +158,7 @@ inline bool IsCNodeDup(const AnfNodePtr &node) { } return false; } + } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc new file mode 100644 index 00000000000..5daeced3a5e --- /dev/null +++ b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "optimizer/irpass/grad_var_prepare.h" +#include +#include +#include +#include + +#include "operator/composite/composite.h" +#include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, FuncGraphPtr func_graph, + AnfNodePtr func_node, bool is_unpack, bool sens_param) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_node); + std::vector nodes; + AnfNodePtr unpack_graph_node = nullptr; + if (is_unpack) { + auto unpack_graph = std::make_shared("unpack_graph", sens_param, true); + nodes.push_back(NewValueNode(unpack_graph)); + nodes.push_back(func_node); + // {unpackcall, {GradOperation, ...}, args...} + std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), + [](const AnfNodePtr& node) { return node; }); + unpack_graph_node = func_graph->NewCNode(nodes); + } else { + auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); + nodes.push_back(NewValueNode(unpack_graph)); + nodes.push_back(func_node); + // {{GradOperation, ...}, args...} + std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), + [](const AnfNodePtr& node) { return node; }); + unpack_graph_node = func_graph->NewCNode(nodes); + } + return unpack_graph_node; +} + +// get metagraph of value node +MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { + ValuePtr value; + if (IsValueNode(node)) { + value = GetValueNode(node)->cast()->function(); + } else { + value = GetValueNode(node); + } + if (value == nullptr) { + return nullptr; + } + return value->cast(); +} + +// check if node is a specific metafuncgraph op +bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { + if (node != nullptr) { + auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); + if (meta_func_graph_ptr == nullptr) { + return false; + } + + if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) { + return true; + } + } + return false; +} + +// {{GradOperation, g, w}, Ys} +// {UnPackCall, {GradOperation, g, w}, Ys} +AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + // {{...}, Ys} + auto inputs_y = node->cast()->inputs(); + std::vector inputs_x; + if (IsCNode(inputs_y[0])) { + inputs_x = inputs_y[0]->cast()->inputs(); + } else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) { + inputs_x = inputs_y[1]->cast()->inputs(); + } else { + return nullptr; + } + + // {{...}, Xs} + if (inputs_x.size() < 2) { + return nullptr; + } + + // {GradOperation, g, w} or {GradOperation, g} + if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) { + return nullptr; + } + + auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]); + if (meta_func == nullptr) { + return nullptr; + } + auto grad_op_ptr = meta_func->cast(); + auto func_node = inputs_x[1]; + if (!IsValueNode(func_node)) { + return nullptr; + } + + AnfNodePtr unpack_graph_node = + GenerateUnpackGraphNode(inputs_y, node->cast()->func_graph(), func_node, + IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param()); + // constuct new grad_opration + inputs_x[1] = unpack_graph_node; + auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x); + if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) { + inputs_y[1] = grad_op_cnode; + } else { + inputs_y[0] = grad_op_cnode; + } + auto cnode = node->func_graph()->NewCNode(inputs_y); + return cnode; +} +} // namespace irpass +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h new file mode 100644 index 00000000000..599d1dca17b --- /dev/null +++ b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ +#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ + +#include +#include +#include +#include + +#include "operator/composite/composite.h" +#include "operator/ops.h" +#include "optimizer/irpass.h" +#include "optimizer/optimizer.h" +#include "ir/visitor.h" +#include "ir/func_graph.h" +#include "ir/func_graph_cloner.h" + +namespace mindspore { +namespace opt { +namespace irpass { + +// {{GradOperation, g, w}, Ys} +// {UnPackCall, {GradOperation, g, w}, Ys} +class GradVarPrepare : public AnfVisitor { + public: + GradVarPrepare() + : grad_op_(std::make_shared("grad")), + unpack_op_(std::make_shared("unpack_call")) {} + ~GradVarPrepare() override = default; + + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; + + private: + MetaFuncGraphPtr grad_op_; + MetaFuncGraphPtr unpack_op_; +}; +} // namespace irpass +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_ diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index f3742ab6541..126992cb8f9 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -175,10 +175,10 @@ bool CombineLikeGraphs(const ResourcePtr&) { bool SymbolResolveAction(const ResourcePtr& res) { if (res->manager() == nullptr) { - MS_LOG(EXCEPTION) << "Resolve error."; + MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; } if (res->func_graph() == nullptr) { - MS_LOG(EXCEPTION) << "Resolve error"; + MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null"; } FuncGraphPtr func_graph = res->func_graph(); auto succ = parse::ResolveFuncGraph(func_graph, res); @@ -194,6 +194,16 @@ bool SymbolResolveAction(const ResourcePtr& res) { return succ; } +bool InferenceOptPrepareAction(const ResourcePtr& res) { + if (res->manager() == nullptr) { + MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; + } + if (res->func_graph() == nullptr) { + MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null."; + } + return InferenceOptPreparePass(res); +} + bool AbstractSpecializeAction(const ResourcePtr& res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "AbstractSpecialize error"; @@ -303,7 +313,7 @@ static std::vector CommonPipeline() { // Resolve the python func actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); - + actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); // Evaluate type and shape, and specialize actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index 02e8c5277b2..e2626d53145 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -160,6 +160,13 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) { return map; } +OptPassGroupMap GetInferenceOptPreparePhases() { + opt::irpass::InferenceOptPrepareLib irpass; + auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_}); + opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}}); + return prepare_map; +} + OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); OptPassGroupMap map({{"prepare_group", prepare_group}}); @@ -239,6 +246,16 @@ bool ValidatePass(const ResourcePtr& res) { return true; } +bool InferenceOptPreparePass(const ResourcePtr& res) { + FuncGraphPtr func_graph = res->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + abstract::AbstractBasePtrList args_spec = res->args_spec(); + auto prepare_map = GetInferenceOptPreparePhases(); + auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map); + (void)infer_opt_prepare->step(func_graph, args_spec, false); + return true; +} + std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, diff --git a/mindspore/ccsrc/pipeline/pass.h b/mindspore/ccsrc/pipeline/pass.h index 03ed8eb3708..3731d7e524c 100644 --- a/mindspore/ccsrc/pipeline/pass.h +++ b/mindspore/ccsrc/pipeline/pass.h @@ -34,7 +34,7 @@ bool CconvPass(const ResourcePtr& res); bool ValidatePass(const ResourcePtr& res); bool ConvertPrepareAdapt(const ResourcePtr& res); bool AddControlDependPass(const ResourcePtr& res); - +bool InferenceOptPreparePass(const ResourcePtr& res); void ReclaimOptimizer(); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h index 3acb22d8296..133d5e99a9d 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h @@ -133,6 +133,7 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { FuncGraphPtr func_graph_; AnalysisContextPtr context_; }; +using FuncGraphAbstractClosurePtr = std::shared_ptr; class MetaFuncGraphAbstractClosure : public AbstractFuncAtom { public: diff --git a/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc b/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc index 9326ded2d54..aeaa6b17f81 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/analysis_context.cc @@ -41,7 +41,7 @@ AnalysisContextPtr AnalysisContext::NewFuncGraphContext(const FuncGraphPtr &func } else { oss << "nullptr"; } - MS_LOG(EXCEPTION) << "" << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); + MS_LOG(EXCEPTION) << oss.str() << " NodeInfo: " << trace::GetDebugInfo(func_graph->debug_info()); } return NewContext(parent_context, func_graph, args_spec_list); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index 98d82de5d59..4110f258110 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -180,6 +180,85 @@ AbstractBasePtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const Config return engine->ForwardConfig(out_conf, fn_conf); } +static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) { + // arg[0] is the func graph to unpack, ignore it + AbstractBasePtrList sepcialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end()); + AbstractBasePtrList graph_sepcialize_args; + if (need_unpack) { + for (size_t index = 0; index < sepcialize_args_before_unpack.size(); index++) { + MS_EXCEPTION_IF_NULL(sepcialize_args_before_unpack[index]); + if (sepcialize_args_before_unpack[index]->isa()) { + AbstractTuplePtr arg_tuple = sepcialize_args_before_unpack[index]->cast(); + std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(), + std::back_inserter(graph_sepcialize_args), [](AbstractBasePtr abs) { return abs; }); + } else if (sepcialize_args_before_unpack[index]->isa()) { + AbstractDictionaryPtr arg_dict = sepcialize_args_before_unpack[index]->cast(); + auto dict_elems = arg_dict->elements(); + (void)std::transform( + dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_sepcialize_args), + [](const AbstractAttribute &item) { return std::make_shared(item.first, item.second); }); + } else { + MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got " + << sepcialize_args_before_unpack[index]->ToString(); + } + } + } else { + graph_sepcialize_args = sepcialize_args_before_unpack; + } + return graph_sepcialize_args; +} + +AbstractBasePtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, + AnfNodeConfigPtr out_conf) { + if (out_conf->node() == nullptr || !out_conf->node()->isa()) { + MS_LOG(EXCEPTION) << "Node of out_conf should be CNode"; + } + if (!prim_->isa()) { + MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString(); + } + + auto unpack_graph = prim_->cast(); + auto out_node = out_conf->node()->cast(); + const auto &out_node_inputs = out_node->inputs(); + if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) { + MS_LOG(EXCEPTION) << "UnpackGraphPrimitive" + << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size() + << ", inputs size " << out_node_inputs.size(); + } + AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()}; + AbstractBasePtrList args_spec_list; + (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); + // get the forward graph + MS_EXCEPTION_IF_NULL(args_spec_list[0]); + AbstractFunctionPtr fn = args_spec_list[0]->cast(); + if (fn == nullptr) { + MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString(); + } + auto real_fn = fn->cast(); + MS_EXCEPTION_IF_NULL(real_fn); + FuncGraphPtr forward_graph = real_fn->func_graph(); + MS_EXCEPTION_IF_NULL(forward_graph); + AbstractBasePtrList graph_sepcialize_args = + GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args()); + + AbstractBasePtrList graph_sepcialize_args_without_sens; + (void)std::transform(graph_sepcialize_args.begin(), + graph_sepcialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0), + std::back_inserter(graph_sepcialize_args_without_sens), [](AbstractBasePtr abs) { return abs; }); + auto new_graph = forward_graph->GenerateGraph(graph_sepcialize_args_without_sens); + engine->func_graph_manager()->AddFuncGraph(new_graph); + ScopePtr scope = kDefaultScope; + if (out_conf != nullptr) { + scope = out_conf->node()->scope(); + } + ScopeGuard scope_guard(scope); + AnfNodePtr new_vnode = NewValueNode(new_graph); + AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context()); + + return engine->ForwardConfig(out_conf, fn_conf); +} + namespace { py::object BuildValue(const ValuePtr &value_ptr) { if (value_ptr == nullptr) { diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index 9dae576a4c2..e154473dbba 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -87,6 +87,21 @@ class DoSignatureEvaluator : public Evaluator { PrimitivePtr prim_; }; +class UnpackGraphEvaluator : public Evaluator { + public: + explicit UnpackGraphEvaluator(const PrimitivePtr primitive) : Evaluator("UnpackGraphEvaluator"), prim_(primitive) {} + ~UnpackGraphEvaluator() override = default; + AbstractBasePtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs, + AnfNodeConfigPtr out_config = nullptr) override; + + AbstractBasePtr Infer(AnalysisEnginePtr, const AbstractBasePtrList &) override { + MS_LOG(EXCEPTION) << "Infer() should not be called, Run() method should be called"; + } + + private: + PrimitivePtr prim_; +}; + bool IsInWhiteList(PrimitivePtr primitive); StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive); diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc index 0bfba265db0..49182e8d09a 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc @@ -289,6 +289,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr evaluator = std::make_shared(prim); return evaluator; } + if (prim->isa()) { + evaluator = std::make_shared(prim); + return evaluator; + } if (prim->HasPyEvaluator()) { auto prim_py = dyn_cast(prim); if (prim_py != nullptr) { diff --git a/tests/ut/python/parameter_feature/test_var_grad.py b/tests/ut/python/parameter_feature/test_var_grad.py index d51b78ed9dc..12c05d05941 100644 --- a/tests/ut/python/parameter_feature/test_var_grad.py +++ b/tests/ut/python/parameter_feature/test_var_grad.py @@ -19,6 +19,8 @@ from mindspore.nn import Cell from mindspore.ops import operations as P import mindspore.ops.composite as C from mindspore.common.api import _executor +from mindspore.common.parameter import ParameterTuple +from mindspore.common import dtype as mstype context.set_context(mode=context.GRAPH_MODE) @@ -34,3 +36,152 @@ def test_net_vargs_expand(): sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32)) net = AddNet() out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens) + +class VarNet(Cell): + def __init__(self, net): + super(VarNet, self).__init__() + self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True) + self.w = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "w", requires_grad=True) + self.net = net + def construct(self, *args): + return self.net(*args)*self.w + self.b + +class SecondNet(Cell): + def __init__(self): + super(SecondNet, self).__init__() + self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True) + def construct(self, *args): + res = args[0] + args[1] + return res + self.b2 +def test_all_var_args_grad_with_sens(): + """"test grad_by_list_with_sens with all var args input""" + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.weights = ParameterTuple(net.trainable_params()) + self.net = net + def construct(self, *inputs): + return C.grad_by_list_with_sens(self.net, self.weights)(*inputs) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + sens = Tensor(1.0, dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y, sens) + +def test_grad_list_var_args(): + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.weights = ParameterTuple(net.trainable_params()) + self.net = net + def construct(self, *inputs): + return C.grad_by_list(self.net, self.weights)(*inputs) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y) + +def test_grad_all_var_args(): + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.weights = ParameterTuple(net.trainable_params()) + self.net = net + def construct(self, *inputs): + return C.grad_all(self.net)(*inputs) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y) + +def test_grad_all_var_args_with_sens(): + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.weights = ParameterTuple(net.trainable_params()) + self.net = net + def construct(self, *inputs): + return C.grad_all_with_sens(self.net)(*inputs) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + sens = Tensor(1.0, dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y, sens) + +def test_grad_var_args_with_sens(): + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.weights = ParameterTuple(net.trainable_params()) + self.net = net + def construct(self, *inputs): + return C.grad_with_sens(self.net)(*inputs) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + sens = Tensor(1.0, dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y, sens) + +def test_var_args_grad(): + class VarNet(Cell): + def __init__(self, net): + super(VarNet, self).__init__() + self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True) + self.net = net + def construct(self, *args): + return self.net(*args) + self.b + + class SecondNet(Cell): + def __init__(self): + super(SecondNet, self).__init__() + self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True) + def construct(self, *args): + res = args[0] + args[1] + return res + self.b2 + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + def construct(self, x, y, sens): + return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + sens = Tensor(1.0, dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y, sens) + + +def test_var_args_positional(): + """"test grad_all with var args in inner graph""" + class VarNet(Cell): + def __init__(self, net): + super(VarNet, self).__init__() + self.net = net + def construct(self, x, y): + return self.net(x, y)*x + + class SecondNet(Cell): + def __init__(self): + super(SecondNet, self).__init__() + def construct(self, *args): + return args[0] + args[1] + + class GradNet(Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.weights = ParameterTuple(net.trainable_params()) + def construct(self, x, y): + return C.grad_all(self.net)(x, y) + x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32) + net = VarNet(SecondNet()) + grad_net = GradNet(net) + out = grad_net(x, y)