diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc new file mode 100644 index 00000000000..9f4759d3a78 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "frontend/operator/graph_bprop/bprop_expander_meta_func_graph.h" +#include "frontend/operator/graph_bprop/utils.h" +#include "frontend/operator/graph_bprop/ops_utils.h" +#include "include/common/utils/utils.h" +#include "pipeline/pynative/grad/bprop_expander/bprop.h" + +namespace mindspore { +namespace graph_bprop { +REGISTER_EXPANDER_BPROP_IMPL(Sin); +} // namespace graph_bprop +} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.h b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.h new file mode 100644 index 00000000000..6094c29ee40 --- /dev/null +++ b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_expander_meta_func_graph.h @@ -0,0 +1,92 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_ + +#include +#include +#include +#include "ir/meta_func_graph.h" +#include "frontend/operator/graph_bprop/bprop_meta_func_graph.h" +#include "pipeline/pynative/grad/bprop_expander/bprop.h" +#include "frontend/optimizer/expander.h" +#include "include/common/debug/anf_ir_dump.h" + +namespace mindspore { +namespace graph_bprop { +constexpr int64_t TWO = 2; +constexpr int64_t ONE = 1; +class BpropExpanderMetaFuncGraph : public BpropMetaFuncGraph { + public: + explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal) : BpropMetaFuncGraph(primal->name(), primal) {} + ~BpropExpanderMetaFuncGraph() override = default; + MS_DECLARE_PARENT(BpropExpanderMetaFuncGraph, BpropMetaFuncGraph); + FuncGraphPtr BpropExpanderFunc(const AbstractBasePtrList &args_spec_list) { + int64_t list_size = SizeToLong(args_spec_list.size()); + auto bprop_fg = std::make_shared(); + std::vector grads; + grads.push_back(NewValueNode(primal_)); + for (int64_t i = 0; i < list_size - TWO; ++i) { + auto abs_i = args_spec_list[i]; + auto x = bprop_fg->add_parameter(); + x->set_abstract(args_spec_list[i]); + x->abstract()->set_value(args_spec_list[i]->BuildValue()); + (void)grads.emplace_back(x); + } + auto out = bprop_fg->add_parameter(); + + out->set_abstract(args_spec_list[list_size - TWO]); + (void)grads.emplace_back(out); + + auto dout = bprop_fg->NewCNode({NewValueNode(prim::kPrimZerosLike), out}); + dout->set_abstract(args_spec_list[list_size - ONE]); + (void)grads.emplace_back(dout); + auto newcnode = bprop_fg->NewCNode(grads); + expander::bprop::BpropExpanderInGraphMode be; + if (be.Run(newcnode)) { + bprop_fg = be.GetGraph(); + } else { + MS_LOG(WARNING) << "Bprop FuncGraph Create Failed!"; + } + + (void)mindspore::opt::ConvertPrimToPrimPy(bprop_fg); + + return bprop_fg; + } + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) override { + return BpropExpanderFunc(input_abs); + } +}; +FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size) { + auto fg = std::make_shared(); + auto meta_graph = std::make_shared(primal); + std::vector inputs{NewValueNode(meta_graph)}; + for (size_t i = 0; i < forward_inputs_size; ++i) { + (void)inputs.emplace_back(fg->add_parameter()); + } + (void)inputs.emplace_back(fg->add_parameter()); + (void)inputs.emplace_back(fg->add_parameter()); + fg->set_output(fg->NewCNode(inputs)); + return fg; +} + +#define STR(s) #s +#define REGISTER_EXPANDER_BPROP_IMPL(name) \ + static auto helper_expand_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetExpandBprop); +} // namespace graph_bprop +} // namespace mindspore +#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_meta_func_graph.h b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_meta_func_graph.h index c9748885370..9111d5bc729 100644 --- a/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_meta_func_graph.h +++ b/mindspore/ccsrc/frontend/operator/graph_bprop/bprop_meta_func_graph.h @@ -38,23 +38,23 @@ class BpropMetaFuncGraph : public MetaFuncGraph { PrimitivePtr primal_; }; -using BpropFunction = std::function; -using PrimitiveBpropImplMap = mindspore::HashMap; +using BpropFunction = std::function; +using PrimitiveBpropImplMap = mindspore::HashMap; PrimitiveBpropImplMap &GetPrimitiveBpropImplMap(); class RegisterPrimitiveBpropHelper { public: - RegisterPrimitiveBpropHelper(const PrimitivePtr &primitive, const BpropFunction &bprop_fn) { + RegisterPrimitiveBpropHelper(const std::string &op_name, const BpropFunction &bprop_fn) { auto &prim_bprop_impl_map = GetPrimitiveBpropImplMap(); - prim_bprop_impl_map[primitive] = bprop_fn; + prim_bprop_impl_map[op_name] = bprop_fn; } ~RegisterPrimitiveBpropHelper() = default; }; #define STR(s) #s -#define REGISTER_PRIMITIVE_BPROP_IMPL(name, primitive, bprop_fn, forward_inputs_size) \ +#define REGISTER_PRIMITIVE_BPROP_IMPL(name, bprop_fn) \ class BpropMetaFuncGraph##name : public BpropMetaFuncGraph { \ public: \ explicit BpropMetaFuncGraph##name(const PrimitivePtr &primal) \ @@ -66,7 +66,7 @@ class RegisterPrimitiveBpropHelper { return bprop_fn(primal_, input_abs); \ } \ }; \ - FuncGraphPtr GetBprop##name(const PrimitivePtr &primal) { \ + FuncGraphPtr GetBprop##name(const PrimitivePtr &primal, const size_t forward_inputs_size) { \ auto fg = std::make_shared(); \ auto meta_graph = std::make_shared(primal); \ std::vector inputs{NewValueNode(meta_graph)}; \ @@ -78,7 +78,7 @@ class RegisterPrimitiveBpropHelper { fg->set_output(fg->NewCNode(inputs)); \ return fg; \ } \ - static auto helper_bprop_##name = RegisterPrimitiveBpropHelper(primitive, GetBprop##name); + static auto helper_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetBprop##name); } // namespace graph_bprop } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_META_FUNC_GRAPH_H_ diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/grad_array_ops.cc b/mindspore/ccsrc/frontend/operator/graph_bprop/grad_array_ops.cc index 3504a19ad3d..3ded42fdddf 100644 --- a/mindspore/ccsrc/frontend/operator/graph_bprop/grad_array_ops.cc +++ b/mindspore/ccsrc/frontend/operator/graph_bprop/grad_array_ops.cc @@ -62,7 +62,7 @@ FuncGraphPtr TransposeBprop(const PrimitivePtr &primal, const AbstractBasePtrLis fg->set_output(NewNode(fg, {MakeTuple(), transpose, zeros_like})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, prim::kPrimTranspose, TransposeBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, TransposeBprop); FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { constexpr size_t expected_arg_size = 4; @@ -90,6 +90,6 @@ FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in fg->set_output(NewNode(fg, {MakeTuple(), return_node, zeros_like_node})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(Cast, prim::kPrimCast, CastBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(Cast, CastBprop); } // namespace graph_bprop } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/grad_math_ops.cc b/mindspore/ccsrc/frontend/operator/graph_bprop/grad_math_ops.cc index 5d5df6db104..4b778c56d19 100644 --- a/mindspore/ccsrc/frontend/operator/graph_bprop/grad_math_ops.cc +++ b/mindspore/ccsrc/frontend/operator/graph_bprop/grad_math_ops.cc @@ -56,7 +56,7 @@ FuncGraphPtr MatMulBprop(const PrimitivePtr &primal, const AbstractBasePtrList & fg->set_output(NewNode(fg, {MakeTuple(), dx, dw})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(MatMul, prim::kPrimMatMul, MatMulBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(MatMul, MatMulBprop); FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -69,7 +69,7 @@ FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp fg->set_output(BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], neg_dout)); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(Sub, prim::kPrimSub, SubBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(Sub, SubBprop); FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -80,7 +80,7 @@ FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], parameters[kIndex3])); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(Add, prim::kPrimAdd, AddBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(Add, AddBprop); FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -94,7 +94,7 @@ FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrLis fg->set_output(NewNode(fg, {MakeTuple(), out1, out2})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, AssignAddBprop); FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto neg_grad = Neg(); @@ -106,7 +106,7 @@ FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp fg->set_output(NewNode(fg, {MakeTuple(), dx})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(Neg, prim::kPrimNeg, NegBprop, 1); +REGISTER_PRIMITIVE_BPROP_IMPL(Neg, NegBprop); FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -118,6 +118,6 @@ FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrLis fg->set_output(NewNode(fg, {MakeTuple(), dx, dy})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, prim::kPrimLogicalOr, LogicalOrBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, LogicalOrBprop); } // namespace graph_bprop } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/graph_bprop/grad_nn_ops.cc b/mindspore/ccsrc/frontend/operator/graph_bprop/grad_nn_ops.cc index 9fd54d2d79a..b8acf54917d 100644 --- a/mindspore/ccsrc/frontend/operator/graph_bprop/grad_nn_ops.cc +++ b/mindspore/ccsrc/frontend/operator/graph_bprop/grad_nn_ops.cc @@ -35,7 +35,7 @@ FuncGraphPtr ReluBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in fg->set_output(NewNode(fg, {MakeTuple(), dx})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(ReLU, prim::kPrimReLU, ReluBprop, 1); +REGISTER_PRIMITIVE_BPROP_IMPL(ReLU, ReluBprop); FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -60,7 +60,7 @@ FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList & fg->set_output(NewNode(fg, {MakeTuple(), dx, dw})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(Conv2D, prim::kPrimConv2D, Conv2DBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(Conv2D, Conv2DBprop); FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -87,7 +87,7 @@ FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis fg->set_output(NewNode(fg, {MakeTuple(), dx, d_gamma, d_beta})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(LayerNorm, prim::kPrimLayerNorm, LayerNormBprop, 3); +REGISTER_PRIMITIVE_BPROP_IMPL(LayerNorm, LayerNormBprop); FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -103,7 +103,7 @@ FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList fg->set_output(NewNode(fg, {MakeTuple(), dx})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(MaxPool, prim::kPrimMaxPool, MaxPoolBprop, 1); +REGISTER_PRIMITIVE_BPROP_IMPL(MaxPool, MaxPoolBprop); FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -140,7 +140,7 @@ FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis NewNode(fg, {MakeTuple(), dx, dscale, dbias, ZerosLikeFunction(fg, mean), ZerosLikeFunction(fg, variance)})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormBprop, 5); +REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, BatchNormBprop); FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -154,7 +154,7 @@ FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList fg->set_output(NewNode(fg, {MakeTuple(), dout, bais_add_grad})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddBprop, 2); +REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, BiasAddBprop); FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { auto fg = NewGraph(input_abs); @@ -169,6 +169,6 @@ FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in fg->set_output(NewNode(fg, {MakeTuple(), dx})); return fg; } -REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, prim::kPrimGeLU, GeLUBprop, 1); +REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, GeLUBprop); } // namespace graph_bprop } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.cc b/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.cc index 96f2fe12362..9d3e456ef32 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.cc @@ -15,6 +15,7 @@ */ #include "frontend/optimizer/ad/bprop_utils.h" + #include #include #include @@ -27,6 +28,9 @@ #include "utils/system/sha256.h" #include "mindspore/core/load_mindir/load_model.h" #include "pipeline/jit/parse/resolve.h" +#include "pipeline/pynative/grad/bprop_expander/bprop.h" +#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h" +#include "frontend/optimizer/expander.h" #include "include/common/debug/dump_proto.h" #include "frontend/operator/ops.h" #include "frontend/optimizer/irpass.h" @@ -305,7 +309,7 @@ bool CheckMindir(const py::object &obj) { } #endif -FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) { +FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) { // Set a child scope named "grad'PrimitiveName'" for the bprop function, // and add "Gradients" to the front. static const std::string gradients_scope = "Gradients/"; @@ -319,9 +323,17 @@ FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr FuncGraphPtr func_graph = nullptr; if (common::GetEnv("MS_DEV_GET_PYTHON_BPROP") != "1") { const auto &bprop_impl_map = graph_bprop::GetPrimitiveBpropImplMap(); - auto iter = bprop_impl_map.find(prim); + auto iter = bprop_impl_map.find(prim->name()); if (iter != bprop_impl_map.end()) { - func_graph = iter->second(prim); + std::vector node_lists = cnode->inputs(); + auto forward_inputs_size = cnode->inputs().size() - 1; + for (size_t i = 1; i < node_lists.size(); i++) { + auto inputi = node_lists[i]; + if (HasAbstractMonad(inputi)) { + --forward_inputs_size; + } + } + func_graph = iter->second(prim, forward_inputs_size); MS_EXCEPTION_IF_NULL(func_graph); func_graph->set_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop, true); if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) { diff --git a/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.h b/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.h index 8050c98b8bd..00396b3d607 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/bprop_utils.h @@ -31,7 +31,8 @@ void ExportBpropToMindir(const py::object &obj, bool force_update); bool CheckMindir(const py::object &obj); #endif // Get bprop function of a primitive. -FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr); +FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr, + const CNodePtr &cnode = nullptr); } // namespace ad } // namespace mindspore #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_BPROP_MANAGER_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index dcc5df57e3f..9184c663634 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -156,7 +156,7 @@ class KPrim { private: FuncGraphPtr GetFprop(const PrimitivePtr &prim) const; FuncGraphPtr GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node, - const pipeline::ResourceBasePtr &resources); + const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode = nullptr); FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const; FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const; // Given a bprop rule, do the K mapping. diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 193dac0eba2..cb12e423531 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -30,6 +30,7 @@ #include "pipeline/jit/resource.h" #include "frontend/optimizer/ad/dfunctor.h" #include "frontend/operator/composite/composite.h" +#include "pipeline/pynative/grad/bprop_expander/bprop.h" #include "include/common/utils/utils.h" #include "utils/symbolic.h" #include "utils/ms_context.h" @@ -37,6 +38,7 @@ #include "pipeline/jit/debug/trace.h" #include "utils/anf_utils.h" #include "frontend/optimizer/ad/bprop_utils.h" +#include "frontend/optimizer/expander.h" namespace mindspore { namespace ad { @@ -47,7 +49,7 @@ constexpr char kLiftedUserDataKey[] = "lifted_from_fv"; } // namespace FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node, - const pipeline::ResourceBasePtr &resources) { + const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(value_node); auto iter = bprop_registry_.find(prim); @@ -55,7 +57,7 @@ FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &v return iter->second; } - FuncGraphPtr bprop_fg = GetBprop(prim, resources); + FuncGraphPtr bprop_fg = GetBprop(prim, resources, cnode); if (bprop_fg != nullptr) { // Set bprop_g graph cache bprop_registry_[prim] = bprop_fg; @@ -218,7 +220,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ } bprop_fg = BpropCut(value_node, resources); } else { - bprop_fg = GetPrimBprop(prim, value_node, resources); + bprop_fg = GetPrimBprop(prim, value_node, resources, cnode); } SetDumpFlag(prim, bprop_fg); diff --git a/mindspore/ccsrc/frontend/optimizer/expander.cc b/mindspore/ccsrc/frontend/optimizer/expander.cc index 200a874d410..60bef3997fc 100644 --- a/mindspore/ccsrc/frontend/optimizer/expander.cc +++ b/mindspore/ccsrc/frontend/optimizer/expander.cc @@ -23,6 +23,7 @@ #include "mindspore/core/utils/anf_utils.h" #include "frontend/parallel/auto_parallel/costmodel.h" #include "frontend/parallel/graph_util/generate_graph.h" +#include "frontend/operator/ops_front_infer_function.h" #include "pybind_api/ir/primitive_py.h" #include "common/graph_kernel/adapter/expander.h" #include "utils/ms_context.h" @@ -54,6 +55,12 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) { if (primitive == nullptr || primitive->isa()) { continue; } + if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) { + continue; + } + if (primitive->isa()) { + continue; + } parallel::OperatorAttrs attrs; const auto iter = op2attrs.find(primitive->name()); if (iter != op2attrs.end()) { diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.cc b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.cc index 71b2009f5b2..918d1a19d87 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.cc +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.cc @@ -27,17 +27,20 @@ namespace expander { namespace bprop { bool BpropExpander::Run(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(outputs_); MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope(); bool ret = true; - outputs_->clear(); + if (outputs_ != nullptr) { + outputs_->clear(); + } try { ret = RunBprop(cnode); } catch (const std::exception &e) { auto node_name = AnfUtils::GetCNodeName(cnode); MS_LOG(DEBUG) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]"; MS_LOG(INFO) << "Python bprop will be used for \"" << node_name << "\""; - outputs_->clear(); + if (outputs_ != nullptr) { + outputs_->clear(); + } ret = false; } MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope(); @@ -56,51 +59,55 @@ const std::vector &BpropExpander::GetUnusedInputs(const CNodePtr &cnode) return handle->unused_inputs; } -NodePtrList BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) { - NodePtrList nodes; - nodes.reserve(cnode->size()); - (void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(nodes), +void BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) { + input_nodes_.reserve(cnode->size()); + (void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_), [ir_builder](const AnfNodePtr &no) { return std::make_shared(no, ir_builder); }); - return nodes; +} + +std::unique_ptr BpropExpander::CreateIRBuilder(const std::string &name, const CNodePtr &cnode, + const std::shared_ptr &infer) { + return std::make_unique(name, cnode->func_graph(), infer); } bool BpropExpander::RunBprop(const CNodePtr &cnode) { auto infer = std::make_shared(); auto name = AnfUtils::GetCNodeName(cnode); - auto ir_builder = std::make_unique(name, cnode->func_graph(), infer); - auto inputs = ExtractInputs(cnode, ir_builder.get()); + auto ir_builder = CreateIRBuilder(name, cnode, infer); + ExtractInputs(cnode, ir_builder.get()); auto &attrs = GetCNodePrimitive(cnode)->attrs(); auto handle = GetBpropHandle(name); if (handle == nullptr) { MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander."; return false; } - auto output_nodes = ir_builder->Run(inputs, attrs, *handle); - if (output_nodes.empty()) { + output_nodes_ = ir_builder->Run(input_nodes_, attrs, *handle); + if (output_nodes_.empty()) { MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty."; return false; } - outputs_->reserve(output_nodes.size()); - (void)std::transform(output_nodes.cbegin(), output_nodes.cend(), std::back_inserter(*outputs_), - [](const NodePtr &node) { - auto cnode = node->get(); - MS_EXCEPTION_IF_NULL(cnode); - return cnode; - }); - PostProcess(inputs); - DumpResult(name, inputs); + PostProcess(); + DumpResult(name); + input_nodes_.clear(); return true; } -void BpropExpander::PostProcess(const NodePtrList &inputs) const { +void BpropExpander::PostProcess() const { + outputs_->reserve(output_nodes_.size()); + (void)std::transform(output_nodes_.cbegin(), output_nodes_.cend(), std::back_inserter(*outputs_), + [](const NodePtr &node) { + auto cnode = node->get(); + return cnode; + }); std::set visited; // do not visit the inputs again. - std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); }); + std::for_each(input_nodes_.cbegin(), input_nodes_.cend(), + [&visited](const NodePtr &node) { visited.insert(node->get()); }); std::queue que; std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); }); - AnfNodePtr dout = inputs.back()->get(); + AnfNodePtr dout = input_nodes_.back()->get(); while (!que.empty()) { auto node = que.front(); que.pop(); @@ -138,7 +145,7 @@ void BpropExpander::PostProcess(const NodePtrList &inputs) const { } } -void BpropExpander::DumpResult(const std::string &name, const NodePtrList &inputs) const { +void BpropExpander::DumpResult(const std::string &name) const { static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on"); if (!dump_result) { return; @@ -146,7 +153,7 @@ void BpropExpander::DumpResult(const std::string &name, const NodePtrList &input auto fg = std::make_shared(); std::map node_map; CNodePtrList newcnodes; - for (auto &inp : inputs) { + for (auto &inp : input_nodes_) { auto p = fg->add_parameter(); p->set_abstract(inp->get()->abstract()); node_map[inp->get()] = p; @@ -211,6 +218,45 @@ void BpropExpander::DumpResult(const std::string &name, const NodePtrList &input } } +void BpropExpanderInGraphMode::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) { + input_nodes_.reserve(cnode->size()); + + (void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_), + [ir_builder, this](const AnfNodePtr &no) { + auto p = this->fg_->add_parameter(); + p->set_abstract(no->abstract()); + return std::make_shared(p, ir_builder); + }); +} + +std::unique_ptr BpropExpanderInGraphMode::CreateIRBuilder(const std::string &name, + const CNodePtr &cnode, + const std::shared_ptr &infer) { + fg_ = std::make_shared(); + return std::make_unique(name, fg_, infer); +} + +void BpropExpanderInGraphMode::PostProcess() const { + AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)}; + AbstractBasePtrList abs; + (void)std::transform(output_nodes_.cbegin(), output_nodes_.cend(), std::back_inserter(new_outputs), + [&abs](const NodePtr &node) { + abs.push_back(node->get()->abstract()); + return node->get(); + }); + auto mt = fg_->NewCNode(new_outputs); + mt->set_abstract(std::make_shared(abs)); + fg_->set_output(mt); +} + +void BpropExpanderInGraphMode::DumpResult(const std::string &name) const { + static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on"); + if (!dump_result) { + return; + } + DumpIR("bprop/bprop_expander_" + name + ".ir", fg_, true); +} + #ifdef _MSC_VER void RegGradArrayOps(); void RegGradClipOps(); diff --git a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.h b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.h index aa4a491f396..d36f1f87784 100644 --- a/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.h +++ b/mindspore/ccsrc/pipeline/pynative/grad/bprop_expander/bprop.h @@ -36,20 +36,39 @@ class BpropExpander { bool Run(const CNodePtr &cnode); const std::vector &GetUnusedInputs(const CNodePtr &cnode) const; - private: + protected: bool RunBprop(const CNodePtr &cnode); - NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder); + virtual void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder); + virtual std::unique_ptr CreateIRBuilder(const std::string &name, const CNodePtr &cnode, + const std::shared_ptr &infer); const BpropHandle *GetBpropHandle(const std::string &name) const { return BpropIRBuilderFactory::Instance().GetBuilder(name); } - void PostProcess(const NodePtrList &inputs) const; - void DumpResult(const std::string &name, const NodePtrList &inputs) const; - - private: + virtual void PostProcess() const; + virtual void DumpResult(const std::string &name) const; + NodePtrList input_nodes_; + // outputs_ must be CNodePtrList, but output_nodes_ may not necessary. output_nodes_ are used to + // create bprop func_graph in graph_mode. + NodePtrList output_nodes_; CNodePtrList *outputs_{nullptr}; UserType *users_{nullptr}; }; +class BpropExpanderInGraphMode : public BpropExpander { + public: + BpropExpanderInGraphMode() {} + ~BpropExpanderInGraphMode() = default; + FuncGraphPtr GetGraph() { return fg_; } + + protected: + FuncGraphPtr fg_{nullptr}; + void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) override; + std::unique_ptr CreateIRBuilder(const std::string &name, const CNodePtr &cnode, + const std::shared_ptr &infer) override; + void PostProcess() const override; + void DumpResult(const std::string &name) const override; +}; + #ifdef _MSC_VER class WinBpropRegister { public: diff --git a/mindspore/core/expander/emitter.cc b/mindspore/core/expander/emitter.cc index f96f55c4e8c..789684471ac 100644 --- a/mindspore/core/expander/emitter.cc +++ b/mindspore/core/expander/emitter.cc @@ -143,6 +143,39 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const { return Emit(prim::kZerosLike, {Tensor(0)}); } } + if (node->isa()) { + if (node->get()->abstract()->isa()) { + return Emit(prim::kZerosLike, {node}); + } + if (node->get()->abstract()->isa()) { + NodePtrList list; + auto abstract_tuple = node->get()->abstract()->cast(); + for (auto &e : abstract_tuple->elements()) { + if (e->isa()) { + auto shape = e->BuildShape()->cast()->shape(); + auto type = e->BuildType()->cast()->element(); + list.emplace_back(Emit("Zeros", {EmitValue(MakeValue(shape)), EmitValue(type)})); + } else if (e->isa()) { + list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())})); + } else { + MS_LOG(WARNING) << "ZerosLike got UNKNOWN TYPE: " << e->ToString(); + list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())})); + } + } + return MakeTuple(list); + } + if (node->get()->abstract()->isa()) { + return Emit(prim::kZerosLike, {Tensor(0)}); + } + auto v = node->get()->abstract()->BuildValue(); + if (v->isa() || v->isa()) { + return Emit(prim::kZerosLike, {Tensor(0, v->type())}); + } + if (v->isa()) { + auto sh = GetValue>(v); + return Emit(prim::kZerosLike, {Tensor(sh)}); + } + } return Emit(prim::kZerosLike, {node}); }