forked from mindspore-Ecosystem/mindspore
!49724 optimize bprop expander compile time in graph mode
Merge pull request !49724 from Gaoxiong/master
This commit is contained in:
commit
390f6e35ba
|
@ -70,7 +70,11 @@ NodePtr Emitter::Emit(const std::string &op_name, const NodePtrList &inputs, con
|
|||
}
|
||||
}
|
||||
}
|
||||
AnfNodePtrList cnode_inputs = {NewValueNode(primc)};
|
||||
return EmitOp(primc, inputs, attrs);
|
||||
}
|
||||
|
||||
NodePtr Emitter::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs, const DAttr &attrs) const {
|
||||
AnfNodePtrList cnode_inputs = {NewValueNode(prim)};
|
||||
cnode_inputs.reserve(inputs.size() + 1);
|
||||
(void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
|
||||
MS_EXCEPTION_IF_NULL(no);
|
||||
|
|
|
@ -18,46 +18,29 @@
|
|||
#include "frontend/operator/graph_bprop/ops_utils.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace graph_bprop {
|
||||
FuncGraphPtr BpropExpanderMetaFuncGraph::BpropExpanderFunc(const AbstractBasePtrList &args_spec_list) {
|
||||
int64_t list_size = SizeToLong(args_spec_list.size());
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> grads;
|
||||
grads.push_back(NewValueNode(primal_));
|
||||
for (int64_t i = 0; i < list_size - kTwo; ++i) {
|
||||
auto abs_i = args_spec_list[i];
|
||||
auto x = fg->add_parameter();
|
||||
x->set_abstract(args_spec_list[i]);
|
||||
x->abstract()->set_value(args_spec_list[i]->BuildValue());
|
||||
(void)grads.emplace_back(x);
|
||||
}
|
||||
auto out = fg->add_parameter();
|
||||
out->set_abstract(args_spec_list[list_size - kTwo]);
|
||||
(void)grads.emplace_back(out);
|
||||
auto dout = fg->add_parameter();
|
||||
dout->set_abstract(args_spec_list[list_size - kOne]);
|
||||
(void)grads.emplace_back(dout);
|
||||
auto newcnode = fg->NewCNode(grads);
|
||||
expander::bprop::BpropExpanderInGraphMode be;
|
||||
FuncGraphPtr bprop_fg = nullptr;
|
||||
if (be.Run(newcnode)) {
|
||||
bprop_fg = be.GetGraph();
|
||||
(void)mindspore::opt::ConvertPrimToPrimPy(bprop_fg);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Expander failed. Prim is: " << primal_->name();
|
||||
}
|
||||
return bprop_fg;
|
||||
}
|
||||
|
||||
FuncGraphPtr BpropExpanderMetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) {
|
||||
return BpropExpanderFunc(input_abs);
|
||||
auto fg = NewGraph(input_abs);
|
||||
try {
|
||||
if (!expander::bprop::ExpandBpropInGraphMode(handle_, primal_, fg)) {
|
||||
return nullptr;
|
||||
}
|
||||
} catch (const py::type_error &ex) {
|
||||
MS_EXCEPTION(TypeError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
|
||||
} catch (const py::value_error &ex) {
|
||||
MS_EXCEPTION(ValueError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << e.what() << "]";
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size) {
|
||||
FuncGraphPtr GetExpandBprop(const BpropHandle *handle, const PrimitivePtr &primal, size_t forward_inputs_size) {
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
auto meta_graph = std::make_shared<BpropExpanderMetaFuncGraph>(primal);
|
||||
auto meta_graph = std::make_shared<BpropExpanderMetaFuncGraph>(primal, handle);
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)};
|
||||
for (size_t i = 0; i < forward_inputs_size; ++i) {
|
||||
(void)inputs.emplace_back(fg->add_parameter());
|
||||
|
|
|
@ -28,22 +28,28 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace graph_bprop {
|
||||
constexpr int64_t kTwo = 2;
|
||||
constexpr int64_t kOne = 1;
|
||||
using BpropHandle = expander::bprop::BpropHandle;
|
||||
|
||||
class BpropExpanderMetaFuncGraph : public BpropMetaFuncGraph {
|
||||
public:
|
||||
explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal) : BpropMetaFuncGraph(primal->name(), primal) {}
|
||||
explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal, const BpropHandle *handle)
|
||||
: BpropMetaFuncGraph(primal->name(), primal), handle_(handle) {}
|
||||
~BpropExpanderMetaFuncGraph() override = default;
|
||||
MS_DECLARE_PARENT(BpropExpanderMetaFuncGraph, BpropMetaFuncGraph);
|
||||
FuncGraphPtr BpropExpanderFunc(const AbstractBasePtrList &args_spec_list);
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) override;
|
||||
|
||||
private:
|
||||
const BpropHandle *handle_;
|
||||
};
|
||||
|
||||
FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size);
|
||||
FuncGraphPtr GetExpandBprop(const BpropHandle *handle, const PrimitivePtr &primal, size_t forward_inputs_size);
|
||||
|
||||
#define STR(s) #s
|
||||
#define REGISTER_EXPANDER_BPROP_IMPL(name) \
|
||||
static auto helper_expand_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(STR(name), GetExpandBprop);
|
||||
#define REGISTER_EXPANDER_BPROP_IMPL(name) \
|
||||
static auto helper_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper( \
|
||||
STR(name), [](const PrimitivePtr &primal, const size_t forward_inputs_size) -> FuncGraphPtr { \
|
||||
static auto *handle = expander::bprop::BpropIRBuilderFactory::Instance().GetBuilder(STR(name)); \
|
||||
return GetExpandBprop(handle, primal, forward_inputs_size); \
|
||||
})
|
||||
|
||||
void RegBpropExpanderOps();
|
||||
} // namespace graph_bprop
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include "mindspore/core/utils/anf_utils.h"
|
||||
#include "frontend/parallel/auto_parallel/costmodel.h"
|
||||
#include "frontend/parallel/graph_util/generate_graph.h"
|
||||
|
@ -57,6 +58,33 @@ const std::map<std::string, std::vector<std::string>> op2attrs = {
|
|||
{prim::kPrimBatchMatMul->name(), {kTransposeA, kTransposeB}}};
|
||||
} // namespace
|
||||
|
||||
ValuePtr ConvertPrimToPrimPy(const PrimitivePtr &primc) {
|
||||
if (primc == nullptr || primc->isa<PrimitivePy>()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (abstract::GetFrontendPrimitiveInferImpl(primc).has_value()) {
|
||||
return nullptr;
|
||||
}
|
||||
if (primc->isa<prim::DoSignaturePrimitive>()) {
|
||||
return nullptr;
|
||||
}
|
||||
parallel::OperatorAttrs attrs;
|
||||
const auto iter = op2attrs.find(primc->name());
|
||||
if (iter != op2attrs.end()) {
|
||||
for (auto &attr : iter->second) {
|
||||
if (primc->HasAttr(attr)) {
|
||||
(void)attrs.emplace_back(std::pair{attr, primc->GetAttr(attr)});
|
||||
} else {
|
||||
MS_LOG(WARNING) << primc->name() << " op do not have attr: " << attr;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto new_prim = parallel::CreateOpInstance(attrs, primc->name(), "");
|
||||
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primc->attrs());
|
||||
return new_prim;
|
||||
}
|
||||
|
||||
class PrimpyConverter {
|
||||
public:
|
||||
bool Run(const FuncGraphPtr &graph) {
|
||||
|
@ -76,29 +104,7 @@ class PrimpyConverter {
|
|||
continue;
|
||||
}
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
|
||||
continue;
|
||||
}
|
||||
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
|
||||
continue;
|
||||
}
|
||||
if (primitive->isa<prim::DoSignaturePrimitive>()) {
|
||||
continue;
|
||||
}
|
||||
parallel::OperatorAttrs attrs;
|
||||
const auto iter = op2attrs.find(primitive->name());
|
||||
if (iter != op2attrs.end()) {
|
||||
for (auto &attr : iter->second) {
|
||||
if (primitive->HasAttr(attr)) {
|
||||
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
|
||||
} else {
|
||||
MS_LOG(WARNING) << primitive->name() << " op do not have attr: " << attr;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
|
||||
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
|
||||
auto new_prim = ConvertPrimToPrimPy(primitive);
|
||||
AnfNodePtrList inputs = {NewValueNode(new_prim)};
|
||||
auto cnode = dyn_cast_ptr<CNode>(node);
|
||||
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
|
||||
|
@ -111,6 +117,7 @@ class PrimpyConverter {
|
|||
private:
|
||||
std::set<FuncGraphPtr> visited_graphs_;
|
||||
};
|
||||
|
||||
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
|
||||
PrimpyConverter c;
|
||||
return c.Run(graph);
|
||||
|
|
|
@ -17,7 +17,8 @@
|
|||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H
|
||||
|
||||
#include "mindspore/core/base/base.h"
|
||||
#include "ir/anf.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
/**
|
||||
|
@ -25,6 +26,7 @@ namespace opt {
|
|||
*/
|
||||
AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node);
|
||||
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph);
|
||||
ValuePtr ConvertPrimToPrimPy(const PrimitivePtr &primc);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H
|
||||
|
|
|
@ -45,7 +45,7 @@ class COMMON_EXPORT Emitter {
|
|||
/// \brief Emit a ValueNode
|
||||
NodePtr EmitValue(const ValuePtr &value) const;
|
||||
|
||||
NodePtr MakeTuple(const NodePtrList &inputs) const { return Emit(prim::kMakeTuple, inputs); }
|
||||
NodePtr MakeTuple(const NodePtrList &inputs) const { return EmitOp(prim::kPrimMakeTuple, inputs, {}); }
|
||||
NodePtr MakeList(const NodePtrList &inputs) const { return Emit("make_list", inputs); }
|
||||
NodePtr TupleGetItem(const NodePtr &input, size_t i) const {
|
||||
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
|
||||
|
@ -199,6 +199,7 @@ class COMMON_EXPORT Emitter {
|
|||
NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) const;
|
||||
|
||||
protected:
|
||||
virtual NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs, const DAttr &attrs) const;
|
||||
NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); }
|
||||
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {
|
||||
auto node = UnifyDtypeAndEmit(op, lhs, rhs);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "include/common/expander/core/infer.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "frontend/optimizer/expander.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace expander {
|
||||
|
@ -218,38 +218,6 @@ void BpropExpander::DumpResult(const std::string &name) const {
|
|||
}
|
||||
}
|
||||
|
||||
bool BpropExpanderInGraphMode::Run(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
|
||||
bool ret = true;
|
||||
if (outputs_ != nullptr) {
|
||||
outputs_->clear();
|
||||
}
|
||||
auto node_name = AnfUtils::GetCNodeName(cnode);
|
||||
try {
|
||||
ret = RunBprop(cnode);
|
||||
} catch (const py::type_error &ex) {
|
||||
MS_EXCEPTION(TypeError) << "Bprop \"" << node_name << "\" encounter a problem: [" << ex.what() << "]";
|
||||
} catch (const py::value_error &ex) {
|
||||
MS_EXCEPTION(ValueError) << "Bprop \"" << node_name << "\" encounter a problem: [" << ex.what() << "]";
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(EXCEPTION) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
|
||||
}
|
||||
MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
|
||||
input_nodes_.reserve(cnode->size());
|
||||
|
||||
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_),
|
||||
[ir_builder, this](const AnfNodePtr &no) {
|
||||
auto p = this->fg_->add_parameter();
|
||||
p->set_abstract(no->abstract());
|
||||
return std::make_shared<Node>(p, ir_builder);
|
||||
});
|
||||
}
|
||||
|
||||
class LazyInfer : public CppInfer {
|
||||
public:
|
||||
void Infer(const NodePtr &node) override { return; }
|
||||
|
@ -276,43 +244,84 @@ class LazyInfer : public CppInfer {
|
|||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<BpropIRBuilder> BpropExpanderInGraphMode::CreateIRBuilder(const std::string &name,
|
||||
const CNodePtr &cnode) {
|
||||
fg_ = std::make_shared<FuncGraph>();
|
||||
ExpanderInferPtr infer;
|
||||
// default use LazyInfer in graph mode.
|
||||
class GraphModeBuilder : public BpropIRBuilder {
|
||||
public:
|
||||
GraphModeBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
|
||||
: BpropIRBuilder(name, func_graph, infer) {}
|
||||
|
||||
NodePtrList Build(const NodePtrList &inputs, const DAttr &attrs, const BpropHandle &handle) {
|
||||
auto outputs = Run(inputs, attrs, handle);
|
||||
auto mt = this->MakeTuple(outputs)->get();
|
||||
func_graph_->set_output(mt);
|
||||
if (has_ctrl_flow_) {
|
||||
// clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
|
||||
auto todos = TopoSort(func_graph_->get_return(), SuccDeeperSimple, AlwaysInclude);
|
||||
for (auto &no : todos) {
|
||||
no->set_abstract(nullptr);
|
||||
if (IsValueNode<FuncGraph>(no)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(no);
|
||||
for (auto &p : fg->parameters()) {
|
||||
p->set_abstract(nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
private:
|
||||
NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs, const DAttr &attrs) const {
|
||||
if (prim->name() == "Switch") {
|
||||
has_ctrl_flow_ = true;
|
||||
}
|
||||
auto primpy = opt::ConvertPrimToPrimPy(prim);
|
||||
AnfNodePtrList cnode_inputs = {NewValueNode(primpy ? primpy : prim)};
|
||||
cnode_inputs.reserve(inputs.size() + 1);
|
||||
(void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
|
||||
MS_EXCEPTION_IF_NULL(no);
|
||||
return no->get();
|
||||
});
|
||||
auto cnode = func_graph_->NewCNode(cnode_inputs);
|
||||
if (scope_ != nullptr) {
|
||||
cnode->set_scope(scope_);
|
||||
}
|
||||
auto node = NewNode(cnode->cast<AnfNodePtr>());
|
||||
infer_->Infer(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
mutable bool has_ctrl_flow_{false};
|
||||
};
|
||||
|
||||
bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph) {
|
||||
static bool use_imm_infer = (common::GetEnv("MS_DEV_BPROP_IMM_INFER") == "on");
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
auto name = prim->name();
|
||||
if (handle == nullptr) {
|
||||
MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
|
||||
return false;
|
||||
}
|
||||
ExpanderInferPtr infer;
|
||||
if (use_imm_infer) {
|
||||
infer = std::make_shared<CppInfer>();
|
||||
} else {
|
||||
infer = std::make_shared<LazyInfer>();
|
||||
}
|
||||
return std::make_unique<BpropIRBuilder>(name, fg_, infer);
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::PostProcess() const {
|
||||
auto mt = output_nodes_[0]->emitter()->MakeTuple(output_nodes_)->get();
|
||||
fg_->set_output(mt);
|
||||
|
||||
// clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
|
||||
auto todos = TopoSort(fg_->get_return(), SuccDeeperSimple, AlwaysInclude);
|
||||
for (auto &no : todos) {
|
||||
no->set_abstract(nullptr);
|
||||
if (IsValueNode<FuncGraph>(no)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(no);
|
||||
for (auto &p : fg->parameters()) {
|
||||
p->set_abstract(nullptr);
|
||||
}
|
||||
}
|
||||
GraphModeBuilder ir_builder(name, graph, infer);
|
||||
auto ¶meters = graph->parameters();
|
||||
NodePtrList inputs;
|
||||
inputs.reserve(parameters.size());
|
||||
(void)std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(inputs),
|
||||
[&ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, &ir_builder); });
|
||||
auto outputs = ir_builder.Build(inputs, prim->attrs(), *handle);
|
||||
if (outputs.empty()) {
|
||||
MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::DumpResult(const std::string &name) const {
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
if (!dump_result) {
|
||||
return;
|
||||
if (dump_result) {
|
||||
DumpIR("bprop/bprop_expander_" + name + ".ir", graph, true);
|
||||
}
|
||||
DumpIR("bprop/bprop_expander_" + name + ".ir", fg_, true);
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
|
|
@ -33,18 +33,18 @@ class BpropExpander {
|
|||
BpropExpander() {}
|
||||
BpropExpander(CNodePtrList *outputs, UserType *users) : outputs_(outputs), users_(users) {}
|
||||
~BpropExpander() = default;
|
||||
virtual bool Run(const CNodePtr &cnode);
|
||||
bool Run(const CNodePtr &cnode);
|
||||
const std::vector<size_t> &GetUnusedInputs(const CNodePtr &cnode) const;
|
||||
|
||||
protected:
|
||||
bool RunBprop(const CNodePtr &cnode);
|
||||
virtual void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
|
||||
virtual std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode);
|
||||
void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
|
||||
std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode);
|
||||
const BpropHandle *GetBpropHandle(const std::string &name) const {
|
||||
return BpropIRBuilderFactory::Instance().GetBuilder(name);
|
||||
}
|
||||
virtual void PostProcess() const;
|
||||
virtual void DumpResult(const std::string &name) const;
|
||||
void PostProcess() const;
|
||||
void DumpResult(const std::string &name) const;
|
||||
NodePtrList input_nodes_;
|
||||
// outputs_ must be CNodePtrList, but output_nodes_ may not necessary. output_nodes_ are used to
|
||||
// create bprop func_graph in graph_mode.
|
||||
|
@ -53,20 +53,7 @@ class BpropExpander {
|
|||
UserType *users_{nullptr};
|
||||
};
|
||||
|
||||
class BpropExpanderInGraphMode : public BpropExpander {
|
||||
public:
|
||||
BpropExpanderInGraphMode() {}
|
||||
~BpropExpanderInGraphMode() = default;
|
||||
bool Run(const CNodePtr &cnode) override;
|
||||
FuncGraphPtr GetGraph() { return fg_; }
|
||||
|
||||
protected:
|
||||
FuncGraphPtr fg_{nullptr};
|
||||
void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) override;
|
||||
std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode) override;
|
||||
void PostProcess() const override;
|
||||
void DumpResult(const std::string &name) const override;
|
||||
};
|
||||
bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph);
|
||||
|
||||
#ifdef _MSC_VER
|
||||
class WinBpropRegister {
|
||||
|
|
Loading…
Reference in New Issue