!49724 optimize bprop expander compile time in graph mode

Merge pull request !49724 from Gaoxiong/master
This commit is contained in:
i-robot 2023-03-06 02:30:39 +00:00 committed by Gitee
commit 390f6e35ba
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 147 additions and 148 deletions

View File

@ -70,7 +70,11 @@ NodePtr Emitter::Emit(const std::string &op_name, const NodePtrList &inputs, con
}
}
}
AnfNodePtrList cnode_inputs = {NewValueNode(primc)};
return EmitOp(primc, inputs, attrs);
}
NodePtr Emitter::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs, const DAttr &attrs) const {
AnfNodePtrList cnode_inputs = {NewValueNode(prim)};
cnode_inputs.reserve(inputs.size() + 1);
(void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
MS_EXCEPTION_IF_NULL(no);

View File

@ -18,46 +18,29 @@
#include "frontend/operator/graph_bprop/ops_utils.h"
#include "include/common/utils/utils.h"
#include "pipeline/pynative/grad/bprop_expander/bprop.h"
#include "include/common/utils/python_adapter.h"
namespace mindspore {
namespace graph_bprop {
/// \brief Build the bprop graph of `primal_` by running the graph-mode bprop expander.
///
/// A fresh FuncGraph gets one parameter per entry of `args_spec_list`. The last two entries are
/// used as the forward output (`out`) and its gradient (`dout`); every entry before them is a
/// forward input whose abstract additionally receives the value built by BuildValue(). The CNode
/// `primal_(inputs..., out, dout)` is handed to BpropExpanderInGraphMode, and the graph it produces
/// is post-processed with opt::ConvertPrimToPrimPy.
///
/// \param[in] args_spec_list Abstracts of the forward inputs followed by `out` and `dout`.
/// \return The expanded bprop graph.
/// \note Raises through MS_LOG(EXCEPTION) when fewer than two abstracts are supplied or when the
///       expander reports failure.
FuncGraphPtr BpropExpanderMetaFuncGraph::BpropExpanderFunc(const AbstractBasePtrList &args_spec_list) {
  const int64_t list_size = SizeToLong(args_spec_list.size());
  // `out` and `dout` are read at list_size - 2 and list_size - 1, so a shorter list would be
  // indexed out of range.
  if (list_size < kTwo) {
    MS_LOG(EXCEPTION) << "Bprop of " << primal_->name() << " needs at least " << kTwo
                      << " inputs (out and dout), but got " << list_size;
  }
  auto fg = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> grads;
  // One slot for the primitive value node plus one per abstract.
  grads.reserve(args_spec_list.size() + 1);
  grads.push_back(NewValueNode(primal_));
  for (int64_t i = 0; i < list_size - kTwo; ++i) {
    auto x = fg->add_parameter();
    x->set_abstract(args_spec_list[i]);
    x->abstract()->set_value(args_spec_list[i]->BuildValue());
    (void)grads.emplace_back(x);
  }
  auto out = fg->add_parameter();
  out->set_abstract(args_spec_list[list_size - kTwo]);
  (void)grads.emplace_back(out);
  auto dout = fg->add_parameter();
  dout->set_abstract(args_spec_list[list_size - kOne]);
  (void)grads.emplace_back(dout);
  auto newcnode = fg->NewCNode(grads);
  expander::bprop::BpropExpanderInGraphMode be;
  FuncGraphPtr bprop_fg = nullptr;
  if (be.Run(newcnode)) {
    bprop_fg = be.GetGraph();
    (void)mindspore::opt::ConvertPrimToPrimPy(bprop_fg);
  } else {
    MS_LOG(EXCEPTION) << "Expander failed. Prim is: " << primal_->name();
  }
  return bprop_fg;
}
/// \brief Generate the bprop graph of `primal_` for the given input abstracts.
///
/// The graph is created with NewGraph(input_abs) and then filled in by
/// expander::bprop::ExpandBpropInGraphMode using the builder handle stored in `handle_`.
/// NOTE(review): NewGraph is not defined in this view; presumably it is inherited from
/// BpropMetaFuncGraph and creates the parameters for `input_abs` — confirm.
///
/// Exceptions escaping the expander are re-raised with the primitive name attached:
/// py::type_error as TypeError, py::value_error as ValueError, and any other std::exception
/// through MS_LOG(EXCEPTION).
///
/// \param[in] input_abs Abstracts of the bprop graph inputs.
/// \return The expanded graph, or nullptr when ExpandBpropInGraphMode returns false.
FuncGraphPtr BpropExpanderMetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) {
  // No early delegation to BpropExpanderFunc here: an unconditional return at the top would make
  // the handle-based expansion below unreachable.
  auto fg = NewGraph(input_abs);
  try {
    if (!expander::bprop::ExpandBpropInGraphMode(handle_, primal_, fg)) {
      return nullptr;
    }
  } catch (const py::type_error &ex) {
    MS_EXCEPTION(TypeError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
  } catch (const py::value_error &ex) {
    MS_EXCEPTION(ValueError) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << ex.what() << "]";
  } catch (const std::exception &e) {
    MS_LOG(EXCEPTION) << "Bprop \"" << primal_->name() << "\" encounter a problem: [" << e.what() << "]";
  }
  return fg;
}
FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size) {
FuncGraphPtr GetExpandBprop(const BpropHandle *handle, const PrimitivePtr &primal, size_t forward_inputs_size) {
auto fg = std::make_shared<FuncGraph>();
auto meta_graph = std::make_shared<BpropExpanderMetaFuncGraph>(primal);
auto meta_graph = std::make_shared<BpropExpanderMetaFuncGraph>(primal, handle);
std::vector<AnfNodePtr> inputs{NewValueNode(meta_graph)};
for (size_t i = 0; i < forward_inputs_size; ++i) {
(void)inputs.emplace_back(fg->add_parameter());

View File

@ -28,22 +28,28 @@
namespace mindspore {
namespace graph_bprop {
constexpr int64_t kTwo = 2;
constexpr int64_t kOne = 1;
using BpropHandle = expander::bprop::BpropHandle;
/// \brief Meta func graph that produces the bprop graph of a primitive through the bprop expander.
///
/// The graph is named after the primitive (`primal->name()`). An optional builder handle can be
/// supplied at construction; it is stored as a raw pointer and this class never deletes it.
class BpropExpanderMetaFuncGraph : public BpropMetaFuncGraph {
 public:
  /// \brief Construct without a builder handle; `handle_` stays nullptr.
  explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal) : BpropMetaFuncGraph(primal->name(), primal) {}
  /// \brief Construct with the builder handle to use when the graph is generated.
  explicit BpropExpanderMetaFuncGraph(const PrimitivePtr &primal, const BpropHandle *handle)
      : BpropMetaFuncGraph(primal->name(), primal), handle_(handle) {}
  ~BpropExpanderMetaFuncGraph() override = default;
  MS_DECLARE_PARENT(BpropExpanderMetaFuncGraph, BpropMetaFuncGraph);
  FuncGraphPtr BpropExpanderFunc(const AbstractBasePtrList &args_spec_list);
  FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &input_abs) override;

 private:
  // Default-initialised so the one-argument constructor does not leave an indeterminate pointer.
  const BpropHandle *handle_{nullptr};
};
FuncGraphPtr GetExpandBprop(const PrimitivePtr &primal, const size_t &forward_inputs_size);
FuncGraphPtr GetExpandBprop(const BpropHandle *handle, const PrimitivePtr &primal, size_t forward_inputs_size);
// Stringify the macro argument.
#define STR(s) #s
// Register the expander-based bprop of primitive `name`.
// The registered callback looks the builder handle up once (function-local static) and forwards
// it, together with the primitive and the forward input count, to GetExpandBprop.
// This is the only definition of the macro: defining it a second time with a different
// replacement list is an invalid redefinition.
#define REGISTER_EXPANDER_BPROP_IMPL(name)                                                            \
  static auto helper_bprop_##name = graph_bprop::RegisterPrimitiveBpropHelper(                        \
    STR(name), [](const PrimitivePtr &primal, const size_t forward_inputs_size) -> FuncGraphPtr {     \
      static auto *handle = expander::bprop::BpropIRBuilderFactory::Instance().GetBuilder(STR(name)); \
      return GetExpandBprop(handle, primal, forward_inputs_size);                                     \
    })
void RegBpropExpanderOps();
} // namespace graph_bprop

View File

@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include <map>
#include <set>
#include "mindspore/core/utils/anf_utils.h"
#include "frontend/parallel/auto_parallel/costmodel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
@ -57,6 +58,33 @@ const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimBatchMatMul->name(), {kTransposeA, kTransposeB}}};
} // namespace
/// \brief Create a replacement primitive instance for `primc` through parallel::CreateOpInstance.
///
/// No conversion is done, and nullptr is returned, when `primc` is null, is already a PrimitivePy,
/// has a frontend infer implementation, or is a DoSignaturePrimitive. For ops listed in
/// `op2attrs`, the listed attributes are passed to CreateOpInstance; if one of them is missing a
/// warning is logged and nullptr is returned. All attributes of `primc` are copied onto the new
/// instance.
///
/// \param[in] primc Primitive to convert.
/// \return The new instance, or nullptr when nothing was converted.
ValuePtr ConvertPrimToPrimPy(const PrimitivePtr &primc) {
  // Short-circuit order matters: the null check must precede every dereference.
  const bool skip_conversion = primc == nullptr || primc->isa<PrimitivePy>() ||
                               abstract::GetFrontendPrimitiveInferImpl(primc).has_value() ||
                               primc->isa<prim::DoSignaturePrimitive>();
  if (skip_conversion) {
    return nullptr;
  }

  parallel::OperatorAttrs init_attrs;
  if (const auto found = op2attrs.find(primc->name()); found != op2attrs.end()) {
    for (const auto &attr_name : found->second) {
      if (!primc->HasAttr(attr_name)) {
        MS_LOG(WARNING) << primc->name() << " op do not have attr: " << attr_name;
        return nullptr;
      }
      (void)init_attrs.emplace_back(std::pair{attr_name, primc->GetAttr(attr_name)});
    }
  }

  auto converted = parallel::CreateOpInstance(init_attrs, primc->name(), "");
  (void)converted->cast_ptr<Primitive>()->SetAttrs(primc->attrs());
  return converted;
}
class PrimpyConverter {
public:
bool Run(const FuncGraphPtr &graph) {
@ -76,29 +104,7 @@ class PrimpyConverter {
continue;
}
auto primitive = GetCNodePrimitive(node);
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
continue;
}
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
continue;
}
if (primitive->isa<prim::DoSignaturePrimitive>()) {
continue;
}
parallel::OperatorAttrs attrs;
const auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
} else {
MS_LOG(WARNING) << primitive->name() << " op do not have attr: " << attr;
return false;
}
}
}
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
auto new_prim = ConvertPrimToPrimPy(primitive);
AnfNodePtrList inputs = {NewValueNode(new_prim)};
auto cnode = dyn_cast_ptr<CNode>(node);
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
@ -111,6 +117,7 @@ class PrimpyConverter {
private:
std::set<FuncGraphPtr> visited_graphs_;
};
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
PrimpyConverter c;
return c.Run(graph);

View File

@ -17,7 +17,8 @@
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H
#include "mindspore/core/base/base.h"
#include "ir/anf.h"
namespace mindspore {
namespace opt {
/**
@ -25,6 +26,7 @@ namespace opt {
*/
AnfNodePtr TryExpandCNodeFE(const AnfNodePtr &node);
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph);
ValuePtr ConvertPrimToPrimPy(const PrimitivePtr &primc);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_EXPANDER_H

View File

@ -45,7 +45,7 @@ class COMMON_EXPORT Emitter {
/// \brief Emit a ValueNode
NodePtr EmitValue(const ValuePtr &value) const;
NodePtr MakeTuple(const NodePtrList &inputs) const { return Emit(prim::kMakeTuple, inputs); }
NodePtr MakeTuple(const NodePtrList &inputs) const { return EmitOp(prim::kPrimMakeTuple, inputs, {}); }
NodePtr MakeList(const NodePtrList &inputs) const { return Emit("make_list", inputs); }
NodePtr TupleGetItem(const NodePtr &input, size_t i) const {
return Emit(prim::kTupleGetItem, {input, Value(static_cast<int64_t>(i))});
@ -199,6 +199,7 @@ class COMMON_EXPORT Emitter {
NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) const;
protected:
virtual NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs, const DAttr &attrs) const;
NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); }
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {
auto node = UnifyDtypeAndEmit(op, lhs, rhs);

View File

@ -21,7 +21,7 @@
#include "include/common/expander/core/infer.h"
#include "utils/anf_utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/python_adapter.h"
#include "frontend/optimizer/expander.h"
namespace mindspore {
namespace expander {
@ -218,38 +218,6 @@ void BpropExpander::DumpResult(const std::string &name) const {
}
}
bool BpropExpanderInGraphMode::Run(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
bool ret = true;
if (outputs_ != nullptr) {
outputs_->clear();
}
auto node_name = AnfUtils::GetCNodeName(cnode);
try {
ret = RunBprop(cnode);
} catch (const py::type_error &ex) {
MS_EXCEPTION(TypeError) << "Bprop \"" << node_name << "\" encounter a problem: [" << ex.what() << "]";
} catch (const py::value_error &ex) {
MS_EXCEPTION(ValueError) << "Bprop \"" << node_name << "\" encounter a problem: [" << ex.what() << "]";
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
}
MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
return ret;
}
void BpropExpanderInGraphMode::ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
input_nodes_.reserve(cnode->size());
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(input_nodes_),
[ir_builder, this](const AnfNodePtr &no) {
auto p = this->fg_->add_parameter();
p->set_abstract(no->abstract());
return std::make_shared<Node>(p, ir_builder);
});
}
class LazyInfer : public CppInfer {
public:
void Infer(const NodePtr &node) override { return; }
@ -276,43 +244,84 @@ class LazyInfer : public CppInfer {
}
};
std::unique_ptr<BpropIRBuilder> BpropExpanderInGraphMode::CreateIRBuilder(const std::string &name,
const CNodePtr &cnode) {
fg_ = std::make_shared<FuncGraph>();
ExpanderInferPtr infer;
// default use LazyInfer in graph mode.
/// \brief BpropIRBuilder used when a bprop is expanded directly into a graph-mode FuncGraph.
///
/// Every emitted primitive is first passed through opt::ConvertPrimToPrimPy, and the builder
/// remembers whether a "Switch" primitive was emitted so that Build() can reset abstracts.
class GraphModeBuilder : public BpropIRBuilder {
 public:
  GraphModeBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
      : BpropIRBuilder(name, func_graph, infer) {}

  /// \brief Run the bprop builder and set the tuple of its outputs as the graph output.
  /// \param[in] inputs Input nodes handed to Run().
  /// \param[in] attrs  Attributes handed to Run().
  /// \param[in] handle Bprop builder handle handed to Run().
  /// \return The output nodes returned by Run().
  NodePtrList Build(const NodePtrList &inputs, const DAttr &attrs, const BpropHandle &handle) {
    auto outputs = Run(inputs, attrs, handle);
    auto mt = this->MakeTuple(outputs)->get();
    func_graph_->set_output(mt);
    if (has_ctrl_flow_) {
      // clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
      auto todos = TopoSort(func_graph_->get_return(), SuccDeeperSimple, AlwaysInclude);
      for (auto &no : todos) {
        no->set_abstract(nullptr);
        if (IsValueNode<FuncGraph>(no)) {
          auto fg = GetValueNode<FuncGraphPtr>(no);
          for (auto &p : fg->parameters()) {
            p->set_abstract(nullptr);
          }
        }
      }
    }
    return outputs;
  }

 private:
  /// \brief Emit a CNode for `prim`, using the converted primitive when a conversion is available.
  /// Marked `override` so a signature drift from the virtual base declaration is a compile error
  /// instead of a silently unused overload.
  NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs, const DAttr &attrs) const override {
    MS_EXCEPTION_IF_NULL(prim);
    if (prim->name() == "Switch") {
      has_ctrl_flow_ = true;
    }
    // ConvertPrimToPrimPy returns nullptr when it does not convert; fall back to the original.
    auto primpy = opt::ConvertPrimToPrimPy(prim);
    AnfNodePtrList cnode_inputs = {NewValueNode(primpy ? primpy : prim)};
    cnode_inputs.reserve(inputs.size() + 1);
    (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
      MS_EXCEPTION_IF_NULL(no);
      return no->get();
    });
    auto cnode = func_graph_->NewCNode(cnode_inputs);
    if (scope_ != nullptr) {
      cnode->set_scope(scope_);
    }
    auto node = NewNode(cnode->cast<AnfNodePtr>());
    infer_->Infer(node);
    return node;
  }

  // Set from the const EmitOp, hence mutable; true once a "Switch" primitive has been emitted.
  mutable bool has_ctrl_flow_{false};
};
bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph) {
static bool use_imm_infer = (common::GetEnv("MS_DEV_BPROP_IMM_INFER") == "on");
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
auto name = prim->name();
if (handle == nullptr) {
MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
return false;
}
ExpanderInferPtr infer;
if (use_imm_infer) {
infer = std::make_shared<CppInfer>();
} else {
infer = std::make_shared<LazyInfer>();
}
return std::make_unique<BpropIRBuilder>(name, fg_, infer);
}
void BpropExpanderInGraphMode::PostProcess() const {
auto mt = output_nodes_[0]->emitter()->MakeTuple(output_nodes_)->get();
fg_->set_output(mt);
// clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
auto todos = TopoSort(fg_->get_return(), SuccDeeperSimple, AlwaysInclude);
for (auto &no : todos) {
no->set_abstract(nullptr);
if (IsValueNode<FuncGraph>(no)) {
auto fg = GetValueNode<FuncGraphPtr>(no);
for (auto &p : fg->parameters()) {
p->set_abstract(nullptr);
}
}
GraphModeBuilder ir_builder(name, graph, infer);
auto &parameters = graph->parameters();
NodePtrList inputs;
inputs.reserve(parameters.size());
(void)std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(inputs),
[&ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, &ir_builder); });
auto outputs = ir_builder.Build(inputs, prim->attrs(), *handle);
if (outputs.empty()) {
MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
return false;
}
}
void BpropExpanderInGraphMode::DumpResult(const std::string &name) const {
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
if (!dump_result) {
return;
if (dump_result) {
DumpIR("bprop/bprop_expander_" + name + ".ir", graph, true);
}
DumpIR("bprop/bprop_expander_" + name + ".ir", fg_, true);
return true;
}
#ifdef _MSC_VER

View File

@ -33,18 +33,18 @@ class BpropExpander {
BpropExpander() {}
BpropExpander(CNodePtrList *outputs, UserType *users) : outputs_(outputs), users_(users) {}
~BpropExpander() = default;
virtual bool Run(const CNodePtr &cnode);
bool Run(const CNodePtr &cnode);
const std::vector<size_t> &GetUnusedInputs(const CNodePtr &cnode) const;
protected:
bool RunBprop(const CNodePtr &cnode);
virtual void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
virtual std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode);
void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder);
std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode);
const BpropHandle *GetBpropHandle(const std::string &name) const {
return BpropIRBuilderFactory::Instance().GetBuilder(name);
}
virtual void PostProcess() const;
virtual void DumpResult(const std::string &name) const;
void PostProcess() const;
void DumpResult(const std::string &name) const;
NodePtrList input_nodes_;
// outputs_ must be CNodePtrList, but output_nodes_ may not necessary. output_nodes_ are used to
// create bprop func_graph in graph_mode.
@ -53,20 +53,7 @@ class BpropExpander {
UserType *users_{nullptr};
};
/// \brief BpropExpander variant that keeps the generated bprop in its own FuncGraph.
///
/// Overrides the expansion hooks of BpropExpander; the graph is exposed through GetGraph() and
/// is nullptr until one has been assigned to `fg_`.
class BpropExpanderInGraphMode : public BpropExpander {
 public:
  BpropExpanderInGraphMode() = default;
  ~BpropExpanderInGraphMode() = default;

  /// \brief Return the graph currently held in `fg_`.
  FuncGraphPtr GetGraph() { return fg_; }
  bool Run(const CNodePtr &cnode) override;

 protected:
  std::unique_ptr<BpropIRBuilder> CreateIRBuilder(const std::string &name, const CNodePtr &cnode) override;
  void ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) override;
  void PostProcess() const override;
  void DumpResult(const std::string &name) const override;

  FuncGraphPtr fg_{nullptr};
};
bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph);
#ifdef _MSC_VER
class WinBpropRegister {