!46797 lazy_expand_funcgraph adapt

Merge pull request !46797 from ZengZitao/lazy_expand_func
This commit is contained in:
i-robot 2023-01-10 03:14:58 +00:00 committed by Gitee
commit 17755b5a61
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 300 additions and 62 deletions

View File

@ -0,0 +1,26 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/operator/graph_bprop/bprop_expander_meta_func_graph.h"
#include "frontend/operator/graph_bprop/utils.h"
#include "frontend/operator/graph_bprop/ops_utils.h"
#include "include/common/utils/utils.h"
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
namespace mindspore {
namespace graph_bprop {
// Registers the expander-based bprop for the op named "Sin": the macro defines a static
// RegisterPrimitiveBpropHelper that stores GetExpandBprop under the stringized name.
REGISTER_EXPANDER_BPROP_IMPL(Sin);
} // namespace graph_bprop
} // namespace mindspore

View File

@ -0,0 +1,92 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_
#include <string>
#include <memory>
#include <vector>
#include "ir/meta_func_graph.h"
#include "frontend/operator/graph_bprop/bprop_meta_func_graph.h"
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
#include "frontend/optimizer/expander.h"
#include "include/common/debug/anf_ir_dump.h"
namespace mindspore {
namespace graph_bprop {
// Distance from the end of the abstract list to the entry used for the `out` parameter
// (see BpropExpanderFunc: args_spec_list[size - TWO]).
constexpr int64_t TWO = 2;
// Distance from the end of the abstract list to the entry used for the `dout` node
// (see BpropExpanderFunc: args_spec_list[size - ONE]).
constexpr int64_t ONE = 1;
// Meta func graph whose backward graph is produced by the C++ bprop expander
// (expander::bprop::BpropExpanderInGraphMode) instead of a hand-written graph builder.
class BpropExpanderMetaFuncGraph : public BpropMetaFuncGraph {
 public:
  explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal) : BpropMetaFuncGraph(primal->name(), primal) {}
  ~BpropExpanderMetaFuncGraph() override = default;
  MS_DECLARE_PARENT(BpropExpanderMetaFuncGraph, BpropMetaFuncGraph);

  // Builds the bprop graph for `primal_` from the given abstracts.
  //
  // `args_spec_list` is read as [input_0, ..., input_n-1, out, dout]: every entry except the
  // last two becomes a forward-input parameter, the second-to-last is the `out` parameter and
  // the last one supplies the abstract of the `dout` node.
  //
  // Returns the graph built by the expander on success. If the expander reports failure, a
  // warning is logged and the scaffold graph created here is returned unchanged.
  // Raises (via MS_LOG(EXCEPTION)) when fewer than two abstracts are supplied.
  FuncGraphPtr BpropExpanderFunc(const AbstractBasePtrList &args_spec_list) {
    const int64_t list_size = SizeToLong(args_spec_list.size());
    // Without this guard `list_size - TWO` goes negative and the indexing below is out of bounds.
    if (list_size < TWO) {
      MS_LOG(EXCEPTION) << "The bprop of " << primal_->name() << " requires at least " << TWO
                        << " input abstracts, but got " << list_size;
    }
    const auto out_index = static_cast<size_t>(list_size - TWO);
    const auto dout_index = static_cast<size_t>(list_size - ONE);

    // Scaffold graph: only used to assemble a CNode (primal, inputs..., out, dout) for the expander.
    auto bprop_fg = std::make_shared<FuncGraph>();
    std::vector<AnfNodePtr> grads;
    grads.reserve(args_spec_list.size() + 1);
    grads.push_back(NewValueNode(primal_));
    for (size_t i = 0; i < out_index; ++i) {
      const auto &input_abs = args_spec_list[i];
      MS_EXCEPTION_IF_NULL(input_abs);
      auto x = bprop_fg->add_parameter();
      x->set_abstract(input_abs);
      // Attach the value built from the abstract to the parameter's abstract.
      x->abstract()->set_value(input_abs->BuildValue());
      (void)grads.emplace_back(x);
    }
    auto out = bprop_fg->add_parameter();
    out->set_abstract(args_spec_list[out_index]);
    (void)grads.emplace_back(out);
    // dout is represented by a ZerosLike(out) node that carries the abstract of the last entry.
    auto dout = bprop_fg->NewCNode({NewValueNode(prim::kPrimZerosLike), out});
    dout->set_abstract(args_spec_list[dout_index]);
    (void)grads.emplace_back(dout);
    auto newcnode = bprop_fg->NewCNode(grads);

    expander::bprop::BpropExpanderInGraphMode be;
    if (be.Run(newcnode)) {
      bprop_fg = be.GetGraph();
    } else {
      MS_LOG(WARNING) << "Bprop FuncGraph Create Failed!";
    }
    (void)mindspore::opt::ConvertPrimToPrimPy(bprop_fg);
    return bprop_fg;
  }

  // MetaFuncGraph entry point: forwards to BpropExpanderFunc.
  FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) override {
    return BpropExpanderFunc(input_abs);
  }
};
// Builds a wrapper graph that calls a BpropExpanderMetaFuncGraph for `primal`.
//
// The wrapper has `forward_inputs_size` parameters followed by two more parameters
// (presumably `out` and `dout`, matching the layout BpropExpanderFunc expects - confirm
// against the caller), and its output is the call of the meta graph on all of them.
//
// Declared `inline` because this function is defined in a header: without it, every
// translation unit including the header would emit its own external definition.
inline FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size) {
  constexpr size_t kTrailingParamNum = 2;
  auto fg = std::make_shared<FuncGraph>();
  auto meta_graph = std::make_shared<BpropExpanderMetaFuncGraph>(primal);
  std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)};
  const size_t param_num = forward_inputs_size + kTrailingParamNum;
  inputs.reserve(param_num + 1);
  for (size_t i = 0; i < param_num; ++i) {
    (void)inputs.emplace_back(fg->add_parameter());
  }
  fg->set_output(fg->NewCNode(inputs));
  return fg;
}
// Turns the macro argument into a string literal.
#define STR(s) #s
// Defines a static RegisterPrimitiveBpropHelper named helper_expand_bprop_<name> whose
// constructor is given the stringized op name and GetExpandBprop as the bprop function.
#define REGISTER_EXPANDER_BPROP_IMPL(name) \
static auto helper_expand_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetExpandBprop);
} // namespace graph_bprop
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_

View File

@ -38,23 +38,23 @@ class BpropMetaFuncGraph : public MetaFuncGraph {
PrimitivePtr primal_; PrimitivePtr primal_;
}; };
using BpropFunction = std::function<FuncGraphPtr(const PrimitivePtr &)>; using BpropFunction = std::function<FuncGraphPtr(const PrimitivePtr &, const size_t)>;
using PrimitiveBpropImplMap = mindspore::HashMap<PrimitivePtr, BpropFunction, PrimitiveHasher, PrimitiveEqual>; using PrimitiveBpropImplMap = mindspore::HashMap<std::string, BpropFunction>;
PrimitiveBpropImplMap &GetPrimitiveBpropImplMap(); PrimitiveBpropImplMap &GetPrimitiveBpropImplMap();
class RegisterPrimitiveBpropHelper { class RegisterPrimitiveBpropHelper {
public: public:
RegisterPrimitiveBpropHelper(const PrimitivePtr &primitive, const BpropFunction &bprop_fn) { RegisterPrimitiveBpropHelper(const std::string &op_name, const BpropFunction &bprop_fn) {
auto &prim_bprop_impl_map = GetPrimitiveBpropImplMap(); auto &prim_bprop_impl_map = GetPrimitiveBpropImplMap();
prim_bprop_impl_map[primitive] = bprop_fn; prim_bprop_impl_map[op_name] = bprop_fn;
} }
~RegisterPrimitiveBpropHelper() = default; ~RegisterPrimitiveBpropHelper() = default;
}; };
#define STR(s) #s #define STR(s) #s
#define REGISTER_PRIMITIVE_BPROP_IMPL(name, primitive, bprop_fn, forward_inputs_size) \ #define REGISTER_PRIMITIVE_BPROP_IMPL(name, bprop_fn) \
class BpropMetaFuncGraph##name : public BpropMetaFuncGraph { \ class BpropMetaFuncGraph##name : public BpropMetaFuncGraph { \
public: \ public: \
explicit BpropMetaFuncGraph##name(const PrimitivePtr &primal) \ explicit BpropMetaFuncGraph##name(const PrimitivePtr &primal) \
@ -66,7 +66,7 @@ class RegisterPrimitiveBpropHelper {
return bprop_fn(primal_, input_abs); \ return bprop_fn(primal_, input_abs); \
} \ } \
}; \ }; \
FuncGraphPtr GetBprop##name(const PrimitivePtr &primal) { \ FuncGraphPtr GetBprop##name(const PrimitivePtr &primal, const size_t forward_inputs_size) { \
auto fg = std::make_shared<FuncGraph>(); \ auto fg = std::make_shared<FuncGraph>(); \
auto meta_graph = std::make_shared<BpropMetaFuncGraph##name>(primal); \ auto meta_graph = std::make_shared<BpropMetaFuncGraph##name>(primal); \
std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)}; \ std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)}; \
@ -78,7 +78,7 @@ class RegisterPrimitiveBpropHelper {
fg->set_output(fg->NewCNode(inputs)); \ fg->set_output(fg->NewCNode(inputs)); \
return fg; \ return fg; \
} \ } \
static auto helper_bprop_##name = RegisterPrimitiveBpropHelper(primitive, GetBprop##name); static auto helper_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetBprop##name);
} // namespace graph_bprop } // namespace graph_bprop
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_META_FUNC_GRAPH_H_ #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_META_FUNC_GRAPH_H_

View File

@ -62,7 +62,7 @@ FuncGraphPtr TransposeBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
fg->set_output(NewNode(fg, {MakeTuple(), transpose, zeros_like})); fg->set_output(NewNode(fg, {MakeTuple(), transpose, zeros_like}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, prim::kPrimTranspose, TransposeBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, TransposeBprop);
FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
constexpr size_t expected_arg_size = 4; constexpr size_t expected_arg_size = 4;
@ -90,6 +90,6 @@ FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in
fg->set_output(NewNode(fg, {MakeTuple(), return_node, zeros_like_node})); fg->set_output(NewNode(fg, {MakeTuple(), return_node, zeros_like_node}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(Cast, prim::kPrimCast, CastBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(Cast, CastBprop);
} // namespace graph_bprop } // namespace graph_bprop
} // namespace mindspore } // namespace mindspore

View File

@ -56,7 +56,7 @@ FuncGraphPtr MatMulBprop(const PrimitivePtr &primal, const AbstractBasePtrList &
fg->set_output(NewNode(fg, {MakeTuple(), dx, dw})); fg->set_output(NewNode(fg, {MakeTuple(), dx, dw}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(MatMul, prim::kPrimMatMul, MatMulBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(MatMul, MatMulBprop);
FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -69,7 +69,7 @@ FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
fg->set_output(BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], neg_dout)); fg->set_output(BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], neg_dout));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(Sub, prim::kPrimSub, SubBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(Sub, SubBprop);
FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -80,7 +80,7 @@ FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], parameters[kIndex3])); BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], parameters[kIndex3]));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(Add, prim::kPrimAdd, AddBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(Add, AddBprop);
FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -94,7 +94,7 @@ FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
fg->set_output(NewNode(fg, {MakeTuple(), out1, out2})); fg->set_output(NewNode(fg, {MakeTuple(), out1, out2}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, AssignAddBprop);
FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto neg_grad = Neg(); auto neg_grad = Neg();
@ -106,7 +106,7 @@ FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
fg->set_output(NewNode(fg, {MakeTuple(), dx})); fg->set_output(NewNode(fg, {MakeTuple(), dx}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(Neg, prim::kPrimNeg, NegBprop, 1); REGISTER_PRIMITIVE_BPROP_IMPL(Neg, NegBprop);
FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -118,6 +118,6 @@ FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
fg->set_output(NewNode(fg, {MakeTuple(), dx, dy})); fg->set_output(NewNode(fg, {MakeTuple(), dx, dy}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, prim::kPrimLogicalOr, LogicalOrBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, LogicalOrBprop);
} // namespace graph_bprop } // namespace graph_bprop
} // namespace mindspore } // namespace mindspore

View File

@ -35,7 +35,7 @@ FuncGraphPtr ReluBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in
fg->set_output(NewNode(fg, {MakeTuple(), dx})); fg->set_output(NewNode(fg, {MakeTuple(), dx}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(ReLU, prim::kPrimReLU, ReluBprop, 1); REGISTER_PRIMITIVE_BPROP_IMPL(ReLU, ReluBprop);
FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -60,7 +60,7 @@ FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList &
fg->set_output(NewNode(fg, {MakeTuple(), dx, dw})); fg->set_output(NewNode(fg, {MakeTuple(), dx, dw}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(Conv2D, prim::kPrimConv2D, Conv2DBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(Conv2D, Conv2DBprop);
FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -87,7 +87,7 @@ FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
fg->set_output(NewNode(fg, {MakeTuple(), dx, d_gamma, d_beta})); fg->set_output(NewNode(fg, {MakeTuple(), dx, d_gamma, d_beta}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(LayerNorm, prim::kPrimLayerNorm, LayerNormBprop, 3); REGISTER_PRIMITIVE_BPROP_IMPL(LayerNorm, LayerNormBprop);
FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -103,7 +103,7 @@ FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList
fg->set_output(NewNode(fg, {MakeTuple(), dx})); fg->set_output(NewNode(fg, {MakeTuple(), dx}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(MaxPool, prim::kPrimMaxPool, MaxPoolBprop, 1); REGISTER_PRIMITIVE_BPROP_IMPL(MaxPool, MaxPoolBprop);
FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -140,7 +140,7 @@ FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
NewNode(fg, {MakeTuple(), dx, dscale, dbias, ZerosLikeFunction(fg, mean), ZerosLikeFunction(fg, variance)})); NewNode(fg, {MakeTuple(), dx, dscale, dbias, ZerosLikeFunction(fg, mean), ZerosLikeFunction(fg, variance)}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormBprop, 5); REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, BatchNormBprop);
FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -154,7 +154,7 @@ FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList
fg->set_output(NewNode(fg, {MakeTuple(), dout, bais_add_grad})); fg->set_output(NewNode(fg, {MakeTuple(), dout, bais_add_grad}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddBprop, 2); REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, BiasAddBprop);
FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) { FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs); auto fg = NewGraph(input_abs);
@ -169,6 +169,6 @@ FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in
fg->set_output(NewNode(fg, {MakeTuple(), dx})); fg->set_output(NewNode(fg, {MakeTuple(), dx}));
return fg; return fg;
} }
REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, prim::kPrimGeLU, GeLUBprop, 1); REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, GeLUBprop);
} // namespace graph_bprop } // namespace graph_bprop
} // namespace mindspore } // namespace mindspore

View File

@ -15,6 +15,7 @@
*/ */
#include "frontend/optimizer/ad/bprop_utils.h" #include "frontend/optimizer/ad/bprop_utils.h"
#include <string> #include <string>
#include <regex> #include <regex>
#include <utility> #include <utility>
@ -27,6 +28,9 @@
#include "utils/system/sha256.h" #include "utils/system/sha256.h"
#include "mindspore/core/load_mindir/load_model.h" #include "mindspore/core/load_mindir/load_model.h"
#include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/parse/resolve.h"
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
#include "frontend/optimizer/expander.h"
#include "include/common/debug/dump_proto.h" #include "include/common/debug/dump_proto.h"
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h" #include "frontend/optimizer/irpass.h"
@ -305,7 +309,7 @@ bool CheckMindir(const py::object &obj) {
} }
#endif #endif
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) { FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
// Set a child scope named "grad'PrimitiveName'" for the bprop function, // Set a child scope named "grad'PrimitiveName'" for the bprop function,
// and add "Gradients" to the front. // and add "Gradients" to the front.
static const std::string gradients_scope = "Gradients/"; static const std::string gradients_scope = "Gradients/";
@ -319,9 +323,17 @@ FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr
FuncGraphPtr func_graph = nullptr; FuncGraphPtr func_graph = nullptr;
if (common::GetEnv("MS_DEV_GET_PYTHON_BPROP") != "1") { if (common::GetEnv("MS_DEV_GET_PYTHON_BPROP") != "1") {
const auto &bprop_impl_map = graph_bprop::GetPrimitiveBpropImplMap(); const auto &bprop_impl_map = graph_bprop::GetPrimitiveBpropImplMap();
auto iter = bprop_impl_map.find(prim); auto iter = bprop_impl_map.find(prim->name());
if (iter != bprop_impl_map.end()) { if (iter != bprop_impl_map.end()) {
func_graph = iter->second(prim); std::vector<AnfNodePtr> node_lists = cnode->inputs();
auto forward_inputs_size = cnode->inputs().size() - 1;
for (size_t i = 1; i < node_lists.size(); i++) {
auto inputi = node_lists[i];
if (HasAbstractMonad(inputi)) {
--forward_inputs_size;
}
}
func_graph = iter->second(prim, forward_inputs_size);
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
func_graph->set_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop, true); func_graph->set_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop, true);
if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) { if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) {

View File

@ -31,7 +31,8 @@ void ExportBpropToMindir(const py::object &obj, bool force_update);
bool CheckMindir(const py::object &obj); bool CheckMindir(const py::object &obj);
#endif #endif
// Get bprop function of a primitive. // Get bprop function of a primitive.
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr); FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr,
const CNodePtr &cnode = nullptr);
} // namespace ad } // namespace ad
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_BPROP_MANAGER_H_ #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_BPROP_MANAGER_H_

View File

@ -156,7 +156,7 @@ class KPrim {
private: private:
FuncGraphPtr GetFprop(const PrimitivePtr &prim) const; FuncGraphPtr GetFprop(const PrimitivePtr &prim) const;
FuncGraphPtr GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node, FuncGraphPtr GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
const pipeline::ResourceBasePtr &resources); const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode = nullptr);
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const; FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const; FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
// Given a bprop rule, do the K mapping. // Given a bprop rule, do the K mapping.

View File

@ -30,6 +30,7 @@
#include "pipeline/jit/resource.h" #include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/dfunctor.h" #include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/composite.h"
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
#include "include/common/utils/utils.h" #include "include/common/utils/utils.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
@ -37,6 +38,7 @@
#include "pipeline/jit/debug/trace.h" #include "pipeline/jit/debug/trace.h"
#include "utils/anf_utils.h" #include "utils/anf_utils.h"
#include "frontend/optimizer/ad/bprop_utils.h" #include "frontend/optimizer/ad/bprop_utils.h"
#include "frontend/optimizer/expander.h"
namespace mindspore { namespace mindspore {
namespace ad { namespace ad {
@ -47,7 +49,7 @@ constexpr char kLiftedUserDataKey[] = "lifted_from_fv";
} // namespace } // namespace
FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node, FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
const pipeline::ResourceBasePtr &resources) { const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(value_node); MS_EXCEPTION_IF_NULL(value_node);
auto iter = bprop_registry_.find(prim); auto iter = bprop_registry_.find(prim);
@ -55,7 +57,7 @@ FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &v
return iter->second; return iter->second;
} }
FuncGraphPtr bprop_fg = GetBprop(prim, resources); FuncGraphPtr bprop_fg = GetBprop(prim, resources, cnode);
if (bprop_fg != nullptr) { if (bprop_fg != nullptr) {
// Set bprop_g graph cache // Set bprop_g graph cache
bprop_registry_[prim] = bprop_fg; bprop_registry_[prim] = bprop_fg;
@ -218,7 +220,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
} }
bprop_fg = BpropCut(value_node, resources); bprop_fg = BpropCut(value_node, resources);
} else { } else {
bprop_fg = GetPrimBprop(prim, value_node, resources); bprop_fg = GetPrimBprop(prim, value_node, resources, cnode);
} }
SetDumpFlag(prim, bprop_fg); SetDumpFlag(prim, bprop_fg);

View File

@ -23,6 +23,7 @@
#include "mindspore/core/utils/anf_utils.h" #include "mindspore/core/utils/anf_utils.h"
#include "frontend/parallel/auto_parallel/costmodel.h" #include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/graph_util/generate_graph.h" #include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/operator/ops_front_infer_function.h"
#include "pybind_api/ir/primitive_py.h" #include "pybind_api/ir/primitive_py.h"
#include "common/graph_kernel/adapter/expander.h" #include "common/graph_kernel/adapter/expander.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
@ -54,6 +55,12 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
if (primitive == nullptr || primitive->isa<PrimitivePy>()) { if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
continue; continue;
} }
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
continue;
}
if (primitive->isa<prim::DoSignaturePrimitive>()) {
continue;
}
parallel::OperatorAttrs attrs; parallel::OperatorAttrs attrs;
const auto iter = op2attrs.find(primitive->name()); const auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) { if (iter != op2attrs.end()) {

View File

@ -27,17 +27,20 @@ namespace expander {
namespace bprop { namespace bprop {
bool BpropExpander::Run(const CNodePtr &cnode) { bool BpropExpander::Run(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(outputs_);
MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope(); MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
bool ret = true; bool ret = true;
outputs_->clear(); if (outputs_ != nullptr) {
outputs_->clear();
}
try { try {
ret = RunBprop(cnode); ret = RunBprop(cnode);
} catch (const std::exception &e) { } catch (const std::exception &e) {
auto node_name = AnfUtils::GetCNodeName(cnode); auto node_name = AnfUtils::GetCNodeName(cnode);
MS_LOG(DEBUG) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]"; MS_LOG(DEBUG) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
MS_LOG(INFO) << "Python bprop will be used for \"" << node_name << "\""; MS_LOG(INFO) << "Python bprop will be used for \"" << node_name << "\"";
outputs_->clear(); if (outputs_ != nullptr) {
outputs_->clear();
}
ret = false; ret = false;
} }
MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope(); MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
@ -56,51 +59,55 @@ const std::vector<size_t> &BpropExpander::GetUnusedInputs(const CNodePtr &cnode)
return handle->unused_inputs; return handle->unused_inputs;
} }
NodePtrList BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) { void BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
NodePtrList nodes; input_nodes_.reserve(cnode->size());
nodes.reserve(cnode->size()); (void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_),
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(nodes),
[ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder); }); [ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder); });
return nodes; }
std::unique_ptr<BpropIRBuilder> BpropExpander::CreateIRBuilder(const std::string &name, const CNodePtr &cnode,
const std::shared_ptr<CppInfer> &infer) {
return std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
} }
bool BpropExpander::RunBprop(const CNodePtr &cnode) { bool BpropExpander::RunBprop(const CNodePtr &cnode) {
auto infer = std::make_shared<CppInfer>(); auto infer = std::make_shared<CppInfer>();
auto name = AnfUtils::GetCNodeName(cnode); auto name = AnfUtils::GetCNodeName(cnode);
auto ir_builder = std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer); auto ir_builder = CreateIRBuilder(name, cnode, infer);
auto inputs = ExtractInputs(cnode, ir_builder.get()); ExtractInputs(cnode, ir_builder.get());
auto &attrs = GetCNodePrimitive(cnode)->attrs(); auto &attrs = GetCNodePrimitive(cnode)->attrs();
auto handle = GetBpropHandle(name); auto handle = GetBpropHandle(name);
if (handle == nullptr) { if (handle == nullptr) {
MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander."; MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
return false; return false;
} }
auto output_nodes = ir_builder->Run(inputs, attrs, *handle); output_nodes_ = ir_builder->Run(input_nodes_, attrs, *handle);
if (output_nodes.empty()) { if (output_nodes_.empty()) {
MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty."; MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
return false; return false;
} }
outputs_->reserve(output_nodes.size()); PostProcess();
(void)std::transform(output_nodes.cbegin(), output_nodes.cend(), std::back_inserter(*outputs_), DumpResult(name);
[](const NodePtr &node) { input_nodes_.clear();
auto cnode = node->get<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
return cnode;
});
PostProcess(inputs);
DumpResult(name, inputs);
return true; return true;
} }
void BpropExpander::PostProcess(const NodePtrList &inputs) const { void BpropExpander::PostProcess() const {
outputs_->reserve(output_nodes_.size());
(void)std::transform(output_nodes_.cbegin(), output_nodes_.cend(), std::back_inserter(*outputs_),
[](const NodePtr &node) {
auto cnode = node->get<CNodePtr>();
return cnode;
});
std::set<AnfNodePtr> visited; std::set<AnfNodePtr> visited;
// do not visit the inputs again. // do not visit the inputs again.
std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); }); std::for_each(input_nodes_.cbegin(), input_nodes_.cend(),
[&visited](const NodePtr &node) { visited.insert(node->get()); });
std::queue<CNodePtr> que; std::queue<CNodePtr> que;
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); }); std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
AnfNodePtr dout = inputs.back()->get(); AnfNodePtr dout = input_nodes_.back()->get();
while (!que.empty()) { while (!que.empty()) {
auto node = que.front(); auto node = que.front();
que.pop(); que.pop();
@ -138,7 +145,7 @@ void BpropExpander::PostProcess(const NodePtrList &inputs) const {
} }
} }
void BpropExpander::DumpResult(const std::string &name, const NodePtrList &inputs) const { void BpropExpander::DumpResult(const std::string &name) const {
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on"); static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
if (!dump_result) { if (!dump_result) {
return; return;
@ -146,7 +153,7 @@ void BpropExpander::DumpResult(const std::string &name, const NodePtrList &input
auto fg = std::make_shared<FuncGraph>(); auto fg = std::make_shared<FuncGraph>();
std::map<AnfNodePtr, AnfNodePtr> node_map; std::map<AnfNodePtr, AnfNodePtr> node_map;
CNodePtrList newcnodes; CNodePtrList newcnodes;
for (auto &inp : inputs) { for (auto &inp : input_nodes_) {
auto p = fg->add_parameter(); auto p = fg->add_parameter();
p->set_abstract(inp->get()->abstract()); p->set_abstract(inp->get()->abstract());
node_map[inp->get()] = p; node_map[inp->get()] = p;
@ -211,6 +218,45 @@ void BpropExpander::DumpResult(const std::string &name, const NodePtrList &input
} }
} }
void BpropExpanderInGraphMode::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
input_nodes_.reserve(cnode->size());
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_),
[ir_builder, this](const AnfNodePtr &no) {
auto p = this->fg_->add_parameter();
p->set_abstract(no->abstract());
return std::make_shared<Node>(p, ir_builder);
});
}
// Replaces `fg_` with a fresh FuncGraph and returns an IR builder bound to that graph.
// `cnode` is not used by this override.
std::unique_ptr<BpropIRBuilder> BpropExpanderInGraphMode::CreateIRBuilder(const std::string &name,
                                                                          const CNodePtr &cnode,
                                                                          const std::shared_ptr<CppInfer> &infer) {
  fg_ = std::make_shared<FuncGraph>();
  auto builder = std::make_unique<BpropIRBuilder>(name, fg_, infer);
  return builder;
}
void BpropExpanderInGraphMode::PostProcess() const {
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abs;
(void)std::transform(output_nodes_.cbegin(), output_nodes_.cend(), std::back_inserter(new_outputs),
[&abs](const NodePtr &node) {
abs.push_back(node->get()->abstract());
return node->get();
});
auto mt = fg_->NewCNode(new_outputs);
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
fg_->set_output(mt);
}
// Passes `fg_` to DumpIR under the file name "bprop/bprop_expander_<name>.ir" when the environment
// variable MS_DEV_DUMP_BPROP equals "on"; otherwise does nothing.
//
// The environment variable is read once, on the first call, and the result is reused afterwards.
//
// name: inserted into the dump file name.
void BpropExpanderInGraphMode::DumpResult(const std::string &name) const {
  static const bool dump_enabled = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
  if (dump_enabled) {
    const std::string ir_file = "bprop/bprop_expander_" + name + ".ir";
    DumpIR(ir_file, fg_, true);
  }
}
#ifdef _MSC_VER #ifdef _MSC_VER
void RegGradArrayOps(); void RegGradArrayOps();
void RegGradClipOps(); void RegGradClipOps();

View File

@ -36,20 +36,39 @@ class BpropExpander {
bool Run(const CNodePtr &cnode); bool Run(const CNodePtr &cnode);
const std::vector<size_t> &GetUnusedInputs(const CNodePtr &cnode) const; const std::vector<size_t> &GetUnusedInputs(const CNodePtr &cnode) const;
private: protected:
bool RunBprop(const CNodePtr &cnode); bool RunBprop(const CNodePtr &cnode);
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder); virtual void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
virtual std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode,
const std::shared_ptr<CppInfer> &infer);
const BpropHandle *GetBpropHandle(const std::string &name) const { const BpropHandle *GetBpropHandle(const std::string &name) const {
return BpropIRBuilderFactory::Instance().GetBuilder(name); return BpropIRBuilderFactory::Instance().GetBuilder(name);
} }
void PostProcess(const NodePtrList &inputs) const; virtual void PostProcess() const;
void DumpResult(const std::string &name, const NodePtrList &inputs) const; virtual void DumpResult(const std::string &name) const;
NodePtrList input_nodes_;
private: // outputs_ must be a CNodePtrList, but output_nodes_ need not hold CNodes. output_nodes_ is used to
// create the bprop func_graph in graph mode.
NodePtrList output_nodes_;
CNodePtrList *outputs_{nullptr}; CNodePtrList *outputs_{nullptr};
UserType *users_{nullptr}; UserType *users_{nullptr};
}; };
// BpropExpander variant that keeps its own FuncGraph (`fg_`) and overrides the input-extraction,
// builder-creation, post-processing and dump hooks of the base class to work on that graph.
class BpropExpanderInGraphMode : public BpropExpander {
 public:
  BpropExpanderInGraphMode() = default;
  ~BpropExpanderInGraphMode() = default;

  // Returns the graph currently held by this expander; nullptr if none has been assigned yet.
  // Const-qualified so it can also be called through a const reference.
  FuncGraphPtr GetGraph() const { return fg_; }

 protected:
  void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) override;
  std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode,
                                                  const std::shared_ptr<CppInfer> &infer) override;
  void PostProcess() const override;
  void DumpResult(const std::string &name) const override;

  // Graph held by this expander; starts out as nullptr.
  FuncGraphPtr fg_{nullptr};
};
#ifdef _MSC_VER #ifdef _MSC_VER
class WinBpropRegister { class WinBpropRegister {
public: public:

View File

@ -143,6 +143,39 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
return Emit(prim::kZerosLike, {Tensor(0)}); return Emit(prim::kZerosLike, {Tensor(0)});
} }
} }
if (node->isa<Parameter>()) {
if (node->get()->abstract()->isa<abstract::AbstractTensor>()) {
return Emit(prim::kZerosLike, {node});
}
if (node->get()->abstract()->isa<abstract::AbstractTuple>()) {
NodePtrList list;
auto abstract_tuple = node->get()->abstract()->cast<abstract::AbstractTuplePtr>();
for (auto &e : abstract_tuple->elements()) {
if (e->isa<abstract::AbstractTensor>()) {
auto shape = e->BuildShape()->cast<abstract::ShapePtr>()->shape();
auto type = e->BuildType()->cast<TensorTypePtr>()->element();
list.emplace_back(Emit("Zeros", {EmitValue(MakeValue(shape)), EmitValue(type)}));
} else if (e->isa<abstract::AbstractScalar>()) {
list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
} else {
MS_LOG(WARNING) << "ZerosLike got UNKNOWN TYPE: " << e->ToString();
list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
}
}
return MakeTuple(list);
}
if (node->get()->abstract()->isa<abstract::AbstractMonad>()) {
return Emit(prim::kZerosLike, {Tensor(0)});
}
auto v = node->get()->abstract()->BuildValue();
if (v->isa<Scalar>() || v->isa<Type>()) {
return Emit(prim::kZerosLike, {Tensor(0, v->type())});
}
if (v->isa<ValueSequence>()) {
auto sh = GetValue<std::vector<int64_t>>(v);
return Emit(prim::kZerosLike, {Tensor(sh)});
}
}
return Emit(prim::kZerosLike, {node}); return Emit(prim::kZerosLike, {node});
} }