forked from mindspore-Ecosystem/mindspore
!46797 lazy_expand_funcgraph adapt
Merge pull request !46797 from ZengZitao/lazy_expand_func
This commit is contained in:
commit
17755b5a61
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "frontend/operator/graph_bprop/bprop_expander_meta_func_graph.h"
|
||||
#include "frontend/operator/graph_bprop/utils.h"
|
||||
#include "frontend/operator/graph_bprop/ops_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace graph_bprop {
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Sin);
|
||||
} // namespace graph_bprop
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,92 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "frontend/operator/graph_bprop/bprop_meta_func_graph.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
#include "frontend/optimizer/expander.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace graph_bprop {
|
||||
constexpr int64_t TWO = 2;
|
||||
constexpr int64_t ONE = 1;
|
||||
class BpropExpanderMetaFuncGraph : public BpropMetaFuncGraph {
|
||||
public:
|
||||
explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal) : BpropMetaFuncGraph(primal->name(), primal) {}
|
||||
~BpropExpanderMetaFuncGraph() override = default;
|
||||
MS_DECLARE_PARENT(BpropExpanderMetaFuncGraph, BpropMetaFuncGraph);
|
||||
FuncGraphPtr BpropExpanderFunc(const AbstractBasePtrList &args_spec_list) {
|
||||
int64_t list_size = SizeToLong(args_spec_list.size());
|
||||
auto bprop_fg = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> grads;
|
||||
grads.push_back(NewValueNode(primal_));
|
||||
for (int64_t i = 0; i < list_size - TWO; ++i) {
|
||||
auto abs_i = args_spec_list[i];
|
||||
auto x = bprop_fg->add_parameter();
|
||||
x->set_abstract(args_spec_list[i]);
|
||||
x->abstract()->set_value(args_spec_list[i]->BuildValue());
|
||||
(void)grads.emplace_back(x);
|
||||
}
|
||||
auto out = bprop_fg->add_parameter();
|
||||
|
||||
out->set_abstract(args_spec_list[list_size - TWO]);
|
||||
(void)grads.emplace_back(out);
|
||||
|
||||
auto dout = bprop_fg->NewCNode({NewValueNode(prim::kPrimZerosLike), out});
|
||||
dout->set_abstract(args_spec_list[list_size - ONE]);
|
||||
(void)grads.emplace_back(dout);
|
||||
auto newcnode = bprop_fg->NewCNode(grads);
|
||||
expander::bprop::BpropExpanderInGraphMode be;
|
||||
if (be.Run(newcnode)) {
|
||||
bprop_fg = be.GetGraph();
|
||||
} else {
|
||||
MS_LOG(WARNING) << "Bprop FuncGraph Create Failed!";
|
||||
}
|
||||
|
||||
(void)mindspore::opt::ConvertPrimToPrimPy(bprop_fg);
|
||||
|
||||
return bprop_fg;
|
||||
}
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) override {
|
||||
return BpropExpanderFunc(input_abs);
|
||||
}
|
||||
};
|
||||
/// Builds the wrapper graph that calls a BpropExpanderMetaFuncGraph for `primal`.
///
/// The wrapper has `forward_inputs_size + 2` parameters; BpropExpanderMetaFuncGraph treats the
/// last two as the forward output and dout. Its output is a single call of the meta graph on
/// all of them.
///
/// Declared `inline` because this definition lives in a header: without it, every translation
/// unit including the header would emit its own external definition.
///
/// @param primal               Forward primitive whose bprop is requested.
/// @param forward_inputs_size  Number of forward inputs of the primitive.
/// @return The wrapper FuncGraph.
inline FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size) {
  // Two extra parameters follow the forward inputs: out and dout.
  constexpr size_t kOutAndDoutNum = 2;
  const size_t param_num = forward_inputs_size + kOutAndDoutNum;

  auto fg = std::make_shared<FuncGraph>();
  auto meta_graph = std::make_shared<BpropExpanderMetaFuncGraph>(primal);
  std::vector<AnfNodePtr> inputs;
  inputs.reserve(param_num + 1);
  (void)inputs.emplace_back(NewValueNode(meta_graph));
  for (size_t i = 0; i < param_num; ++i) {
    (void)inputs.emplace_back(fg->add_parameter());
  }
  fg->set_output(fg->NewCNode(inputs));
  return fg;
}
|
||||
|
||||
// Stringizes its argument: STR(Sin) expands to the string literal "Sin".
#define STR(s) #s
|
||||
// Registers GetExpandBprop as the bprop graph builder for operator `name`: defines a static
// RegisterPrimitiveBpropHelper object keyed by the stringized operator name.
// The expansion already ends with ';'.
#define REGISTER_EXPANDER_BPROP_IMPL(name) \
|
||||
static auto helper_expand_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetExpandBprop);
|
||||
} // namespace graph_bprop
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_EXPANDER_META_FUNC_GRAPH_H_
|
|
@ -38,23 +38,23 @@ class BpropMetaFuncGraph : public MetaFuncGraph {
|
|||
PrimitivePtr primal_;
|
||||
};
|
||||
|
||||
using BpropFunction = std::function<FuncGraphPtr(const PrimitivePtr &)>;
|
||||
using PrimitiveBpropImplMap = mindspore::HashMap<PrimitivePtr, BpropFunction, PrimitiveHasher, PrimitiveEqual>;
|
||||
using BpropFunction = std::function<FuncGraphPtr(const PrimitivePtr &, const size_t)>;
|
||||
using PrimitiveBpropImplMap = mindspore::HashMap<std::string, BpropFunction>;
|
||||
|
||||
PrimitiveBpropImplMap &GetPrimitiveBpropImplMap();
|
||||
|
||||
class RegisterPrimitiveBpropHelper {
|
||||
public:
|
||||
RegisterPrimitiveBpropHelper(const PrimitivePtr &primitive, const BpropFunction &bprop_fn) {
|
||||
RegisterPrimitiveBpropHelper(const std::string &op_name, const BpropFunction &bprop_fn) {
|
||||
auto &prim_bprop_impl_map = GetPrimitiveBpropImplMap();
|
||||
prim_bprop_impl_map[primitive] = bprop_fn;
|
||||
prim_bprop_impl_map[op_name] = bprop_fn;
|
||||
}
|
||||
~RegisterPrimitiveBpropHelper() = default;
|
||||
};
|
||||
|
||||
#define STR(s) #s
|
||||
|
||||
#define REGISTER_PRIMITIVE_BPROP_IMPL(name, primitive, bprop_fn, forward_inputs_size) \
|
||||
#define REGISTER_PRIMITIVE_BPROP_IMPL(name, bprop_fn) \
|
||||
class BpropMetaFuncGraph##name : public BpropMetaFuncGraph { \
|
||||
public: \
|
||||
explicit BpropMetaFuncGraph##name(const PrimitivePtr &primal) \
|
||||
|
@ -66,7 +66,7 @@ class RegisterPrimitiveBpropHelper {
|
|||
return bprop_fn(primal_, input_abs); \
|
||||
} \
|
||||
}; \
|
||||
FuncGraphPtr GetBprop##name(const PrimitivePtr &primal) { \
|
||||
FuncGraphPtr GetBprop##name(const PrimitivePtr &primal, const size_t forward_inputs_size) { \
|
||||
auto fg = std::make_shared<FuncGraph>(); \
|
||||
auto meta_graph = std::make_shared<BpropMetaFuncGraph##name>(primal); \
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)}; \
|
||||
|
@ -78,7 +78,7 @@ class RegisterPrimitiveBpropHelper {
|
|||
fg->set_output(fg->NewCNode(inputs)); \
|
||||
return fg; \
|
||||
} \
|
||||
static auto helper_bprop_##name = RegisterPrimitiveBpropHelper(primitive, GetBprop##name);
|
||||
static auto helper_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetBprop##name);
|
||||
} // namespace graph_bprop
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_GRAPH_BPROP_BPROP_META_FUNC_GRAPH_H_
|
||||
|
|
|
@ -62,7 +62,7 @@ FuncGraphPtr TransposeBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), transpose, zeros_like}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, prim::kPrimTranspose, TransposeBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, TransposeBprop);
|
||||
|
||||
FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
constexpr size_t expected_arg_size = 4;
|
||||
|
@ -90,6 +90,6 @@ FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), return_node, zeros_like_node}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Cast, prim::kPrimCast, CastBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Cast, CastBprop);
|
||||
} // namespace graph_bprop
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -56,7 +56,7 @@ FuncGraphPtr MatMulBprop(const PrimitivePtr &primal, const AbstractBasePtrList &
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx, dw}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(MatMul, prim::kPrimMatMul, MatMulBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(MatMul, MatMulBprop);
|
||||
|
||||
FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -69,7 +69,7 @@ FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
|
|||
fg->set_output(BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], neg_dout));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Sub, prim::kPrimSub, SubBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Sub, SubBprop);
|
||||
|
||||
FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -80,7 +80,7 @@ FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
|
|||
BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], parameters[kIndex3]));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Add, prim::kPrimAdd, AddBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Add, AddBprop);
|
||||
|
||||
FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -94,7 +94,7 @@ FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), out1, out2}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, AssignAddBprop);
|
||||
|
||||
FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto neg_grad = Neg();
|
||||
|
@ -106,7 +106,7 @@ FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Neg, prim::kPrimNeg, NegBprop, 1);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Neg, NegBprop);
|
||||
|
||||
FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -118,6 +118,6 @@ FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx, dy}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, prim::kPrimLogicalOr, LogicalOrBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, LogicalOrBprop);
|
||||
} // namespace graph_bprop
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -35,7 +35,7 @@ FuncGraphPtr ReluBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(ReLU, prim::kPrimReLU, ReluBprop, 1);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(ReLU, ReluBprop);
|
||||
|
||||
FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -60,7 +60,7 @@ FuncGraphPtr Conv2DBprop(const PrimitivePtr &primal, const AbstractBasePtrList &
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx, dw}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Conv2D, prim::kPrimConv2D, Conv2DBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(Conv2D, Conv2DBprop);
|
||||
|
||||
FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -87,7 +87,7 @@ FuncGraphPtr LayerNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx, d_gamma, d_beta}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(LayerNorm, prim::kPrimLayerNorm, LayerNormBprop, 3);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(LayerNorm, LayerNormBprop);
|
||||
|
||||
FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -103,7 +103,7 @@ FuncGraphPtr MaxPoolBprop(const PrimitivePtr &primal, const AbstractBasePtrList
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(MaxPool, prim::kPrimMaxPool, MaxPoolBprop, 1);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(MaxPool, MaxPoolBprop);
|
||||
|
||||
FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -140,7 +140,7 @@ FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
|
|||
NewNode(fg, {MakeTuple(), dx, dscale, dbias, ZerosLikeFunction(fg, mean), ZerosLikeFunction(fg, variance)}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormBprop, 5);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, BatchNormBprop);
|
||||
|
||||
FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -154,7 +154,7 @@ FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dout, bais_add_grad}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddBprop, 2);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, BiasAddBprop);
|
||||
|
||||
FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
|
||||
auto fg = NewGraph(input_abs);
|
||||
|
@ -169,6 +169,6 @@ FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &in
|
|||
fg->set_output(NewNode(fg, {MakeTuple(), dx}));
|
||||
return fg;
|
||||
}
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, prim::kPrimGeLU, GeLUBprop, 1);
|
||||
REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, GeLUBprop);
|
||||
} // namespace graph_bprop
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
*/
|
||||
|
||||
#include "frontend/optimizer/ad/bprop_utils.h"
|
||||
|
||||
#include <string>
|
||||
#include <regex>
|
||||
#include <utility>
|
||||
|
@ -27,6 +28,9 @@
|
|||
#include "utils/system/sha256.h"
|
||||
#include "mindspore/core/load_mindir/load_model.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop_irbuilder.h"
|
||||
#include "frontend/optimizer/expander.h"
|
||||
#include "include/common/debug/dump_proto.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
|
@ -305,7 +309,7 @@ bool CheckMindir(const py::object &obj) {
|
|||
}
|
||||
#endif
|
||||
|
||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
|
||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
|
||||
// Set a child scope named "grad'PrimitiveName'" for the bprop function,
|
||||
// and add "Gradients" to the front.
|
||||
static const std::string gradients_scope = "Gradients/";
|
||||
|
@ -319,9 +323,17 @@ FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr
|
|||
FuncGraphPtr func_graph = nullptr;
|
||||
if (common::GetEnv("MS_DEV_GET_PYTHON_BPROP") != "1") {
|
||||
const auto &bprop_impl_map = graph_bprop::GetPrimitiveBpropImplMap();
|
||||
auto iter = bprop_impl_map.find(prim);
|
||||
auto iter = bprop_impl_map.find(prim->name());
|
||||
if (iter != bprop_impl_map.end()) {
|
||||
func_graph = iter->second(prim);
|
||||
std::vector<AnfNodePtr> node_lists = cnode->inputs();
|
||||
auto forward_inputs_size = cnode->inputs().size() - 1;
|
||||
for (size_t i = 1; i < node_lists.size(); i++) {
|
||||
auto inputi = node_lists[i];
|
||||
if (HasAbstractMonad(inputi)) {
|
||||
--forward_inputs_size;
|
||||
}
|
||||
}
|
||||
func_graph = iter->second(prim, forward_inputs_size);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
func_graph->set_flag(mindspore::kFuncGraphFlagMetaFuncGraphBprop, true);
|
||||
if (GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP)) {
|
||||
|
|
|
@ -31,7 +31,8 @@ void ExportBpropToMindir(const py::object &obj, bool force_update);
|
|||
bool CheckMindir(const py::object &obj);
|
||||
#endif
|
||||
// Get bprop function of a primitive.
|
||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr);
|
||||
FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources = nullptr,
|
||||
const CNodePtr &cnode = nullptr);
|
||||
} // namespace ad
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_AD_BPROP_MANAGER_H_
|
||||
|
|
|
@ -156,7 +156,7 @@ class KPrim {
|
|||
private:
|
||||
FuncGraphPtr GetFprop(const PrimitivePtr &prim) const;
|
||||
FuncGraphPtr GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
|
||||
const pipeline::ResourceBasePtr &resources);
|
||||
const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode = nullptr);
|
||||
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
|
||||
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const;
|
||||
// Given a bprop rule, do the K mapping.
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "pipeline/jit/resource.h"
|
||||
#include "frontend/optimizer/ad/dfunctor.h"
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -37,6 +38,7 @@
|
|||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "frontend/optimizer/ad/bprop_utils.h"
|
||||
#include "frontend/optimizer/expander.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ad {
|
||||
|
@ -47,7 +49,7 @@ constexpr char kLiftedUserDataKey[] = "lifted_from_fv";
|
|||
} // namespace
|
||||
|
||||
FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
|
||||
const pipeline::ResourceBasePtr &resources) {
|
||||
const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
auto iter = bprop_registry_.find(prim);
|
||||
|
@ -55,7 +57,7 @@ FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &v
|
|||
return iter->second;
|
||||
}
|
||||
|
||||
FuncGraphPtr bprop_fg = GetBprop(prim, resources);
|
||||
FuncGraphPtr bprop_fg = GetBprop(prim, resources, cnode);
|
||||
if (bprop_fg != nullptr) {
|
||||
// Set bprop_g graph cache
|
||||
bprop_registry_[prim] = bprop_fg;
|
||||
|
@ -218,7 +220,7 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
}
|
||||
bprop_fg = BpropCut(value_node, resources);
|
||||
} else {
|
||||
bprop_fg = GetPrimBprop(prim, value_node, resources);
|
||||
bprop_fg = GetPrimBprop(prim, value_node, resources, cnode);
|
||||
}
|
||||
|
||||
SetDumpFlag(prim, bprop_fg);
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mindspore/core/utils/anf_utils.h"
|
||||
#include "frontend/parallel/auto_parallel/costmodel.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
#include "frontend/operator/ops_front_infer_function.h"
|
||||
#include "pybind_api/ir/primitive_py.h"
|
||||
#include "common/graph_kernel/adapter/expander.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -54,6 +55,12 @@ bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
|
|||
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
|
||||
continue;
|
||||
}
|
||||
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
|
||||
continue;
|
||||
}
|
||||
if (primitive->isa<prim::DoSignaturePrimitive>()) {
|
||||
continue;
|
||||
}
|
||||
parallel::OperatorAttrs attrs;
|
||||
const auto iter = op2attrs.find(primitive->name());
|
||||
if (iter != op2attrs.end()) {
|
||||
|
|
|
@ -27,17 +27,20 @@ namespace expander {
|
|||
namespace bprop {
|
||||
bool BpropExpander::Run(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(outputs_);
|
||||
MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
|
||||
bool ret = true;
|
||||
outputs_->clear();
|
||||
if (outputs_ != nullptr) {
|
||||
outputs_->clear();
|
||||
}
|
||||
try {
|
||||
ret = RunBprop(cnode);
|
||||
} catch (const std::exception &e) {
|
||||
auto node_name = AnfUtils::GetCNodeName(cnode);
|
||||
MS_LOG(DEBUG) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
|
||||
MS_LOG(INFO) << "Python bprop will be used for \"" << node_name << "\"";
|
||||
outputs_->clear();
|
||||
if (outputs_ != nullptr) {
|
||||
outputs_->clear();
|
||||
}
|
||||
ret = false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
|
||||
|
@ -56,51 +59,55 @@ const std::vector<size_t> &BpropExpander::GetUnusedInputs(const CNodePtr &cnode)
|
|||
return handle->unused_inputs;
|
||||
}
|
||||
|
||||
NodePtrList BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
|
||||
NodePtrList nodes;
|
||||
nodes.reserve(cnode->size());
|
||||
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(nodes),
|
||||
void BpropExpander::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
|
||||
input_nodes_.reserve(cnode->size());
|
||||
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_),
|
||||
[ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder); });
|
||||
return nodes;
|
||||
}
|
||||
|
||||
std::unique_ptr<BpropIRBuilder> BpropExpander::CreateIRBuilder(const std::string &name, const CNodePtr &cnode,
|
||||
const std::shared_ptr<CppInfer> &infer) {
|
||||
return std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
|
||||
}
|
||||
|
||||
bool BpropExpander::RunBprop(const CNodePtr &cnode) {
|
||||
auto infer = std::make_shared<CppInfer>();
|
||||
auto name = AnfUtils::GetCNodeName(cnode);
|
||||
auto ir_builder = std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
|
||||
auto inputs = ExtractInputs(cnode, ir_builder.get());
|
||||
auto ir_builder = CreateIRBuilder(name, cnode, infer);
|
||||
ExtractInputs(cnode, ir_builder.get());
|
||||
auto &attrs = GetCNodePrimitive(cnode)->attrs();
|
||||
auto handle = GetBpropHandle(name);
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
|
||||
return false;
|
||||
}
|
||||
auto output_nodes = ir_builder->Run(inputs, attrs, *handle);
|
||||
if (output_nodes.empty()) {
|
||||
output_nodes_ = ir_builder->Run(input_nodes_, attrs, *handle);
|
||||
if (output_nodes_.empty()) {
|
||||
MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
|
||||
return false;
|
||||
}
|
||||
outputs_->reserve(output_nodes.size());
|
||||
(void)std::transform(output_nodes.cbegin(), output_nodes.cend(), std::back_inserter(*outputs_),
|
||||
[](const NodePtr &node) {
|
||||
auto cnode = node->get<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return cnode;
|
||||
});
|
||||
PostProcess(inputs);
|
||||
DumpResult(name, inputs);
|
||||
PostProcess();
|
||||
DumpResult(name);
|
||||
input_nodes_.clear();
|
||||
return true;
|
||||
}
|
||||
|
||||
void BpropExpander::PostProcess(const NodePtrList &inputs) const {
|
||||
void BpropExpander::PostProcess() const {
|
||||
outputs_->reserve(output_nodes_.size());
|
||||
(void)std::transform(output_nodes_.cbegin(), output_nodes_.cend(), std::back_inserter(*outputs_),
|
||||
[](const NodePtr &node) {
|
||||
auto cnode = node->get<CNodePtr>();
|
||||
return cnode;
|
||||
});
|
||||
std::set<AnfNodePtr> visited;
|
||||
// do not visit the inputs again.
|
||||
std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); });
|
||||
std::for_each(input_nodes_.cbegin(), input_nodes_.cend(),
|
||||
[&visited](const NodePtr &node) { visited.insert(node->get()); });
|
||||
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
AnfNodePtr dout = inputs.back()->get();
|
||||
AnfNodePtr dout = input_nodes_.back()->get();
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
|
@ -138,7 +145,7 @@ void BpropExpander::PostProcess(const NodePtrList &inputs) const {
|
|||
}
|
||||
}
|
||||
|
||||
void BpropExpander::DumpResult(const std::string &name, const NodePtrList &inputs) const {
|
||||
void BpropExpander::DumpResult(const std::string &name) const {
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
if (!dump_result) {
|
||||
return;
|
||||
|
@ -146,7 +153,7 @@ void BpropExpander::DumpResult(const std::string &name, const NodePtrList &input
|
|||
auto fg = std::make_shared<FuncGraph>();
|
||||
std::map<AnfNodePtr, AnfNodePtr> node_map;
|
||||
CNodePtrList newcnodes;
|
||||
for (auto &inp : inputs) {
|
||||
for (auto &inp : input_nodes_) {
|
||||
auto p = fg->add_parameter();
|
||||
p->set_abstract(inp->get()->abstract());
|
||||
node_map[inp->get()] = p;
|
||||
|
@ -211,6 +218,45 @@ void BpropExpander::DumpResult(const std::string &name, const NodePtrList &input
|
|||
}
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
|
||||
input_nodes_.reserve(cnode->size());
|
||||
|
||||
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_),
|
||||
[ir_builder, this](const AnfNodePtr &no) {
|
||||
auto p = this->fg_->add_parameter();
|
||||
p->set_abstract(no->abstract());
|
||||
return std::make_shared<Node>(p, ir_builder);
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<BpropIRBuilder> BpropExpanderInGraphMode::CreateIRBuilder(const std::string &name,
|
||||
const CNodePtr &cnode,
|
||||
const std::shared_ptr<CppInfer> &infer) {
|
||||
fg_ = std::make_shared<FuncGraph>();
|
||||
return std::make_unique<BpropIRBuilder>(name, fg_, infer);
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::PostProcess() const {
|
||||
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtrList abs;
|
||||
(void)std::transform(output_nodes_.cbegin(), output_nodes_.cend(), std::back_inserter(new_outputs),
|
||||
[&abs](const NodePtr &node) {
|
||||
abs.push_back(node->get()->abstract());
|
||||
return node->get();
|
||||
});
|
||||
auto mt = fg_->NewCNode(new_outputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
|
||||
fg_->set_output(mt);
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::DumpResult(const std::string &name) const {
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
if (!dump_result) {
|
||||
return;
|
||||
}
|
||||
DumpIR("bprop/bprop_expander_" + name + ".ir", fg_, true);
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
void RegGradArrayOps();
|
||||
void RegGradClipOps();
|
||||
|
|
|
@ -36,20 +36,39 @@ class BpropExpander {
|
|||
bool Run(const CNodePtr &cnode);
|
||||
const std::vector<size_t> &GetUnusedInputs(const CNodePtr &cnode) const;
|
||||
|
||||
private:
|
||||
protected:
|
||||
bool RunBprop(const CNodePtr &cnode);
|
||||
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
|
||||
virtual void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
|
||||
virtual std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode,
|
||||
const std::shared_ptr<CppInfer> &infer);
|
||||
const BpropHandle *GetBpropHandle(const std::string &name) const {
|
||||
return BpropIRBuilderFactory::Instance().GetBuilder(name);
|
||||
}
|
||||
void PostProcess(const NodePtrList &inputs) const;
|
||||
void DumpResult(const std::string &name, const NodePtrList &inputs) const;
|
||||
|
||||
private:
|
||||
virtual void PostProcess() const;
|
||||
virtual void DumpResult(const std::string &name) const;
|
||||
NodePtrList input_nodes_;
|
||||
// outputs_ must be a CNodePtrList, but output_nodes_ need not hold only CNodes. output_nodes_ are used to
|
||||
// create bprop func_graph in graph_mode.
|
||||
NodePtrList output_nodes_;
|
||||
CNodePtrList *outputs_{nullptr};
|
||||
UserType *users_{nullptr};
|
||||
};
|
||||
|
||||
class BpropExpanderInGraphMode : public BpropExpander {
|
||||
public:
|
||||
BpropExpanderInGraphMode() {}
|
||||
~BpropExpanderInGraphMode() = default;
|
||||
FuncGraphPtr GetGraph() { return fg_; }
|
||||
|
||||
protected:
|
||||
FuncGraphPtr fg_{nullptr};
|
||||
void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) override;
|
||||
std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode,
|
||||
const std::shared_ptr<CppInfer> &infer) override;
|
||||
void PostProcess() const override;
|
||||
void DumpResult(const std::string &name) const override;
|
||||
};
|
||||
|
||||
#ifdef _MSC_VER
|
||||
class WinBpropRegister {
|
||||
public:
|
||||
|
|
|
@ -143,6 +143,39 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
|
|||
return Emit(prim::kZerosLike, {Tensor(0)});
|
||||
}
|
||||
}
|
||||
if (node->isa<Parameter>()) {
|
||||
if (node->get()->abstract()->isa<abstract::AbstractTensor>()) {
|
||||
return Emit(prim::kZerosLike, {node});
|
||||
}
|
||||
if (node->get()->abstract()->isa<abstract::AbstractTuple>()) {
|
||||
NodePtrList list;
|
||||
auto abstract_tuple = node->get()->abstract()->cast<abstract::AbstractTuplePtr>();
|
||||
for (auto &e : abstract_tuple->elements()) {
|
||||
if (e->isa<abstract::AbstractTensor>()) {
|
||||
auto shape = e->BuildShape()->cast<abstract::ShapePtr>()->shape();
|
||||
auto type = e->BuildType()->cast<TensorTypePtr>()->element();
|
||||
list.emplace_back(Emit("Zeros", {EmitValue(MakeValue(shape)), EmitValue(type)}));
|
||||
} else if (e->isa<abstract::AbstractScalar>()) {
|
||||
list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
|
||||
} else {
|
||||
MS_LOG(WARNING) << "ZerosLike got UNKNOWN TYPE: " << e->ToString();
|
||||
list.emplace_back(Emit(prim::kZerosLike, {Tensor(0, e->BuildType())}));
|
||||
}
|
||||
}
|
||||
return MakeTuple(list);
|
||||
}
|
||||
if (node->get()->abstract()->isa<abstract::AbstractMonad>()) {
|
||||
return Emit(prim::kZerosLike, {Tensor(0)});
|
||||
}
|
||||
auto v = node->get()->abstract()->BuildValue();
|
||||
if (v->isa<Scalar>() || v->isa<Type>()) {
|
||||
return Emit(prim::kZerosLike, {Tensor(0, v->type())});
|
||||
}
|
||||
if (v->isa<ValueSequence>()) {
|
||||
auto sh = GetValue<std::vector<int64_t>>(v);
|
||||
return Emit(prim::kZerosLike, {Tensor(sh)});
|
||||
}
|
||||
}
|
||||
return Emit(prim::kZerosLike, {node});
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue