expander supports control-flow. (if-else and while-loop)

This commit is contained in:
dayschan 2023-01-31 16:44:08 +08:00
parent 360add4014
commit 451f8b5756
5 changed files with 296 additions and 39 deletions

View File

@ -43,45 +43,63 @@ const std::map<std::string, std::vector<std::string>> op2attrs = {
{prim::kPrimMatMul->name(), {kTransposeA, kTransposeB}}};
}
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto todos = TopoSort(graph->get_return());
auto mng = Manage({graph}, false);
for (const auto &node : todos) {
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto primitive = GetCNodePrimitive(node);
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
continue;
}
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
continue;
}
if (primitive->isa<prim::DoSignaturePrimitive>()) {
continue;
}
parallel::OperatorAttrs attrs;
const auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
} else {
MS_LOG(WARNING) << primitive->name() << " op do not have attr: " << attr;
return false;
class PrimpyConverter {
public:
bool Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
(void)visited_graphs_.insert(graph);
auto todos = TopoSort(graph->get_return());
auto mng = Manage({graph}, false);
for (const auto &node : todos) {
if (node->isa<ValueNode>()) {
auto sub_graph = node->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
if (sub_graph != nullptr && visited_graphs_.count(sub_graph) == 0) {
(void)Run(sub_graph);
continue;
}
}
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto primitive = GetCNodePrimitive(node);
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
continue;
}
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
continue;
}
if (primitive->isa<prim::DoSignaturePrimitive>()) {
continue;
}
parallel::OperatorAttrs attrs;
const auto iter = op2attrs.find(primitive->name());
if (iter != op2attrs.end()) {
for (auto &attr : iter->second) {
if (primitive->HasAttr(attr)) {
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
} else {
MS_LOG(WARNING) << primitive->name() << " op do not have attr: " << attr;
return false;
}
}
}
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
AnfNodePtrList inputs = {NewValueNode(new_prim)};
auto cnode = dyn_cast_ptr<CNode>(node);
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
auto new_cnode = graph->NewCNodeInOrder(inputs);
(void)mng->Replace(node, new_cnode);
}
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
AnfNodePtrList inputs = {NewValueNode(new_prim)};
auto cnode = dyn_cast_ptr<CNode>(node);
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
auto new_cnode = graph->NewCNodeInOrder(inputs);
(void)mng->Replace(node, new_cnode);
return true;
}
return true;
private:
std::set<FuncGraphPtr> visited_graphs_;
};
// Entry point of the conversion: runs a fresh PrimpyConverter on `graph` and returns its result.
// A new converter is used for every call, so its visited-graph record is not shared between calls.
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
  return PrimpyConverter().Run(graph);
}
#ifdef ENABLE_AKG

View File

@ -247,6 +247,18 @@ void BpropExpanderInGraphMode::PostProcess() const {
auto mt = fg_->NewCNode(new_outputs);
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
fg_->set_output(mt);
// clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
auto todos = TopoSort(fg_->get_return(), SuccDeeperSimple, AlwaysInclude);
for (auto &no : todos) {
no->set_abstract(nullptr);
if (IsValueNode<FuncGraph>(no)) {
auto fg = GetValueNode<FuncGraphPtr>(no);
for (auto &p : fg->parameters()) {
p->set_abstract(nullptr);
}
}
}
}
void BpropExpanderInGraphMode::DumpResult(const std::string &name) const {

View File

@ -392,6 +392,208 @@ std::tuple<NodePtr, NodePtr> Emitter::UnifyDtype2(const NodePtr &lhs, const Node
return {lhs, this->Cast(rhs, lhs->dtype())};
}
/// \brief Helper that builds control-flow structures (if-else and while-loop) on behalf of an Emitter.
///
/// Every block function is emitted into its own FuncGraph. The first block built by an instance fixes the expected
/// output count (output_num_) and the abstract (out_abstract_) that is attached to the generated call nodes; blocks
/// built later by the same instance must return the same number of outputs, otherwise an exception is raised.
class Emitter::CtrlFlowBlock {
public:
/// \brief Bind the builder to the emitter of the enclosing graph. `emitter` is checked to be non-null.
explicit CtrlFlowBlock(const Emitter *emitter) : emitter_(emitter) { MS_EXCEPTION_IF_NULL(emitter); }
~CtrlFlowBlock() = default;
/// \brief Build an if-else structure in the emitter's graph: a call of `Switch(cond, true_graph, false_graph)`.
///
/// \param[in] cond The condition node, passed as the first input of Switch.
/// \param[in] true_case Function that emits the true branch into its own sub-graph.
/// \param[in] false_case Function that emits the false branch into its own sub-graph.
/// \return The node that calls the Switch result. Its abstract is out_abstract_, which is recorded from the first
/// sub-graph built by this instance (the true branch when the instance is fresh).
NodePtr IfThenElse(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
auto tb = BuildSubgraph(true_case);
auto fb = BuildSubgraph(false_case);
auto s = emitter_->Emit("Switch", {cond, tb, fb});
// A CNode whose only input is the Switch node, i.e. a call of the selected branch without arguments.
auto cnode = emitter_->func_graph_->NewCNode({s->get()});
cnode->set_abstract(out_abstract_);
auto node = emitter_->NewNode(cnode->cast<AnfNodePtr>());
return node;
}
/// \brief Build a while-loop as a separate graph (`while_fg`) that is called from the emitter's graph.
///
/// `while_fg` evaluates `cond` and calls `Switch(cond, Partial(body, ...), Partial(exit, ...))`. The body graph ends
/// with a call of `while_fg` (the next iteration); the exit graph returns its inputs, which originate from
/// `init_list`.
///
/// \param[in] cond The loop condition. It must wrap a CNode (checked below); that CNode is moved into `while_fg`
/// and its non-ValueNode inputs are replaced by parameters of `while_fg`.
/// \param[in] while_body_func Function that emits the loop body.
/// \param[in] init_list The initial loop variables. The body must produce exactly `init_list.size()` outputs,
/// otherwise an exception is raised.
/// \return The node in the emitter's graph that calls `while_fg`; its abstract is out_abstract_.
NodePtr While(const NodePtr &cond, const BlockFunc &while_body_func, const NodePtrList &init_list) {
auto while_fg = std::make_shared<FuncGraph>();
MS_EXCEPTION_IF_NULL(while_fg);
auto cond_cnode = cond->get<CNodePtr>();
MS_EXCEPTION_IF_NULL(cond_cnode);
// The condition is evaluated inside the loop graph, so the cond node is re-homed to while_fg.
cond_cnode->set_func_graph(while_fg);
// CppInferWithPartial skips the inference of the Partial and Switch nodes emitted below.
auto while_fg_emitter = std::make_unique<Emitter>(while_fg, std::make_shared<CppInferWithPartial>());
// Inputs of the call of while_fg from the emitter's graph: the graph itself followed by the outer nodes.
AnfNodePtrList main_while_fg_inputs = {NewValueNode(while_fg)};
std::map<AnfNodePtr, ParameterPtr> param_map;
// Map an outer node to a parameter of while_fg. On first use the parameter is created and the outer node is
// recorded as the matching call argument, so parameters()[k] of while_fg pairs with main_while_fg_inputs[k + 1].
auto replace_by_param = [&main_while_fg_inputs, &param_map, &while_fg](const AnfNodePtr &inp) {
auto &param = param_map[inp];
if (param == nullptr) {
param = while_fg->add_parameter();
param->set_abstract(inp->abstract());
(void)main_while_fg_inputs.emplace_back(inp);
}
return param;
};
// The exit branch: a graph that simply returns init_list (the emitter argument is unused).
auto empty_body_func = [&init_list](const Emitter *e) { return init_list; };
auto empty_body_fg_with_inputs = BuildSubgraphOfPartial(empty_body_func);
// Rewrite the captured inputs of the exit graph to parameters of while_fg. This runs before any other call of
// replace_by_param, so these parameters come first in while_fg.
// NOTE(review): the recursive call built below passes the body's outputs in the leading argument positions, which
// presumably relies on every init_list node being a distinct node that belongs to another graph - confirm.
for (size_t i = 1; i < empty_body_fg_with_inputs.size(); i++) {
auto inp = empty_body_fg_with_inputs[i]->get();
empty_body_fg_with_inputs[i] = while_fg_emitter->NewNode(replace_by_param(inp));
}
// Rewrite the inputs of the condition node (ValueNode inputs are kept as they are).
for (size_t i = 1; i < cond_cnode->size(); i++) {
auto inp = cond_cnode->input(i);
if (!inp->isa<ValueNode>()) {
cond_cnode->set_input(i, replace_by_param(inp));
}
}
// Build the loop body; element 0 is the value node of the body graph, the rest are its captured inputs.
auto body_with_inputs = BuildSubgraphOfPartial(while_body_func);
auto body_fg = body_with_inputs[0]->get<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
for (size_t i = 1; i < body_with_inputs.size(); i++) {
body_with_inputs[i] = while_fg_emitter->NewNode(replace_by_param(body_with_inputs[i]->get()));
}
// replace the body's output to call the outside while-fg
AnfNodePtrList body_while_fg_inputs{NewValueNode(while_fg)};
// The body's outputs (the elements of its MakeTuple output, or its single output) become the leading arguments
// of the recursive call.
if (IsPrimitiveCNode(body_fg->output(), prim::kPrimMakeTuple)) {
auto mt = body_fg->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(mt);
(void)body_while_fg_inputs.insert(body_while_fg_inputs.end(), mt->inputs().begin() + 1, mt->inputs().end());
} else {
body_while_fg_inputs.push_back(body_fg->output());
}
if (body_while_fg_inputs.size() - 1 != init_list.size()) {
MS_LOG(EXCEPTION) << "The while body's output size should be equal to init_list.size(), but got "
<< (body_while_fg_inputs.size() - 1) << " vs " << init_list.size();
}
// while_fg may have more parameters than loop variables (extra nodes captured by the condition or the body).
// They have to be passed on unchanged by the recursive call.
if (body_while_fg_inputs.size() < main_while_fg_inputs.size()) {
for (size_t i = body_while_fg_inputs.size(); i < main_while_fg_inputs.size(); i++) {
auto inp = while_fg->parameters()[i - 1];
auto iter = std::find_if(body_with_inputs.begin(), body_with_inputs.end(),
[&inp](const NodePtr &no) { return no->get() == inp; });
if (iter != body_with_inputs.end()) {
// The body already receives this value: forward the body's own parameter (index shifted by the graph entry).
auto param_idx = iter - body_with_inputs.begin() - 1;
body_while_fg_inputs.push_back(body_fg->parameters()[param_idx]);
} else {
// The body does not use this value: add a pass-through parameter to the body and feed it from while_fg.
body_with_inputs.push_back(while_fg_emitter->NewNode(inp));
auto p = body_fg->add_parameter();
p->set_abstract(inp->abstract());
body_while_fg_inputs.push_back(p);
}
}
}
// The body now ends with a call of while_fg, i.e. the next iteration of the loop.
auto body_call_fg = body_fg->NewCNode(body_while_fg_inputs);
body_call_fg->set_abstract(out_abstract_);
body_fg->set_output(body_call_fg);
// while_fg's output: call the graph selected by Switch(cond, Partial(body), Partial(exit)).
auto tb = while_fg_emitter->Emit("Partial", body_with_inputs);
auto fb = while_fg_emitter->Emit("Partial", empty_body_fg_with_inputs);
auto s = while_fg_emitter->Emit("Switch", {cond, tb, fb});
auto cnode = while_fg_emitter->func_graph_->NewCNode({s->get()});
cnode->set_abstract(out_abstract_);
while_fg->set_output(cnode);
// Finally call while_fg from the emitter's graph with the collected outer nodes.
auto main_cnode = emitter_->func_graph_->NewCNode(main_while_fg_inputs);
main_cnode->set_abstract(out_abstract_);
return emitter_->NewNode(main_cnode);
}
protected:
/// \brief Emit `func` into a new FuncGraph and return a value node (created by emitter_) that holds the graph.
///
/// Multiple outputs are packed into a MakeTuple whose elements are all flagged as used. Unlike
/// BuildSubgraphOfPartial, nodes captured from other graphs are not rewritten to parameters here.
/// Raises an exception if `func` returns an empty list or an output count that differs from output_num_.
NodePtr BuildSubgraph(const BlockFunc &func) {
auto fg = std::make_shared<FuncGraph>();
MS_EXCEPTION_IF_NULL(fg);
fg->set_switch_input(std::make_shared<bool>(true));
// The sub-graph is emitted with its own emitter that shares the infer object of emitter_.
auto e = std::make_unique<Emitter>(fg, emitter_->infer());
auto output = func(e.get());
if (output.empty()) {
MS_LOG(EXCEPTION) << "The block function should not return empty list.";
}
// The first block fixes the output count; later blocks must agree with it.
if (output_num_ == 0) {
output_num_ = output.size();
} else if (output_num_ != output.size()) {
MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
<< " vs " << output.size() << ".";
}
if (output_num_ > 1) {
auto mt = e->MakeTuple(output)->get();
fg->set_output(mt);
SetSequenceNodeElementsUseFlags(mt, std::make_shared<std::vector<bool>>(output_num_, true));
} else {
fg->set_output(output[0]->get());
}
// Only the first built block determines out_abstract_.
if (out_abstract_ == nullptr) {
out_abstract_ = fg->output()->abstract();
}
return emitter_->Value(fg);
}
/// \brief Emit `func` into a new FuncGraph and replace every node captured from another graph by a parameter.
///
/// \return A list whose element 0 is the value node of the new graph, followed by the captured outer nodes in the
/// order of the parameters created for them. The callers in this class use it as the input list of "Partial".
/// Raises an exception if `func` returns an empty list or an output count that differs from output_num_.
NodePtrList BuildSubgraphOfPartial(const BlockFunc &func) {
auto fg = std::make_shared<FuncGraph>();
MS_EXCEPTION_IF_NULL(fg);
fg->set_switch_input(std::make_shared<bool>(true));
auto sub_emitter = std::make_unique<Emitter>(fg, emitter_->infer());
auto output = func(sub_emitter.get());
if (output.empty()) {
MS_LOG(EXCEPTION) << "The block function should not return empty list.";
}
// The first block fixes the output count; later blocks must agree with it.
if (output_num_ == 0) {
output_num_ = output.size();
} else if (output_num_ != output.size()) {
MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
<< " vs " << output.size() << ".";
}
fg->set_output((output_num_ > 1) ? sub_emitter->MakeTuple(output)->get() : output[0]->get());
if (out_abstract_ == nullptr) {
out_abstract_ = fg->output()->abstract();
}
if (output_num_ > 1) {
SetSequenceNodeElementsUseFlags(fg->output(), std::make_shared<std::vector<bool>>(output_num_, true));
}
// replace the captured inputs to parameter
std::function<void(const CNodePtr &)> dfs;
std::unordered_set<AnfNodePtr> visited;
// One parameter per distinct captured node, so a node captured several times is passed only once.
std::map<AnfNodePtr, ParameterPtr> param_map;
NodePtrList fg_with_inputs = {emitter_->Value(fg)};
// Depth-first walk over the CNodes of fg, starting from its return node.
dfs = [&visited, &dfs, &fg, &param_map, &fg_with_inputs, this](const CNodePtr &node) {
(void)visited.insert(node);
for (size_t i = 0; i < node->size(); i++) {
auto inp = node->input(i);
// Inputs that belong to no graph are left untouched.
if (inp->func_graph() == nullptr) {
continue;
}
if (inp->func_graph() == fg) {
// Input inside this sub-graph: keep walking through unvisited CNodes.
if (inp->isa<CNode>() && visited.count(inp) == 0) {
dfs(inp->cast<CNodePtr>());
}
} else {
// Input from another graph: route it through a parameter of fg.
auto &param = param_map[inp];
if (param == nullptr) {
param = fg->add_parameter();
param->set_abstract(inp->abstract());
(void)fg_with_inputs.emplace_back(emitter_->NewNode(inp));
}
node->set_input(i, param);
}
}
};
dfs(fg->get_return());
return fg_with_inputs;
}
// Number of outputs of every block built by this instance; 0 means that no block has been built yet.
size_t output_num_{0};
// The emitter of the graph in which the control-flow structure is created.
const Emitter *emitter_;
// Abstract of the output of the first built sub-graph; attached to the call nodes created by this class.
abstract::AbstractBasePtr out_abstract_{nullptr};
/// \brief Infer object that does nothing for Partial and Switch nodes and defers to CppInfer::Infer otherwise.
class CppInferWithPartial : public CppInfer {
public:
void Infer(const NodePtr &node) override {
if (IsPrimitiveCNode(node->get(), prim::kPrimPartial) || IsPrimitiveCNode(node->get(), prim::kPrimSwitch)) {
return;
}
CppInfer::Infer(node);
}
};
};
// Builds an if-else structure with a CtrlFlowBlock that lives only for the duration of this call.
NodePtr Emitter::Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) const {
  CtrlFlowBlock block(this);
  return block.IfThenElse(cond, true_case, false_case);
}
// Builds a while-loop structure with a CtrlFlowBlock that lives only for the duration of this call.
NodePtr Emitter::While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) const {
  CtrlFlowBlock block(this);
  return block.While(cond, body, init_list);
}
// Addition of two nodes, emitted through the emitter that `lhs` belongs to.
NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs) {
  auto emitter = lhs->emitter();
  return emitter->Add(lhs, rhs);
}
// Subtraction of two nodes, emitted through the emitter that `lhs` belongs to.
NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs) {
  auto emitter = lhs->emitter();
  return emitter->Sub(lhs, rhs);
}
// Multiplication of two nodes, emitted through the emitter that `lhs` belongs to.
NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs) {
  auto emitter = lhs->emitter();
  return emitter->Mul(lhs, rhs);
}

View File

@ -35,6 +35,7 @@ class MS_CORE_API Emitter {
public:
/// \brief Construct an Emitter.
///
/// \param[in] func_graph The graph stored in this emitter; checked to be non-null.
/// \param[in] infer The infer object stored in this emitter; checked to be non-null.
/// \param[in] scope The scope stored in this emitter, nullptr by default.
Emitter(const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr)
: func_graph_(func_graph), infer_(infer), scope_(scope) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(infer);
}
@ -153,6 +154,7 @@ class MS_CORE_API Emitter {
return EmitValue(tensor_ptr);
}
/// \brief Get the infer object (ExpanderInferPtr) held by this emitter.
ExpanderInferPtr infer() const { return infer_; }
/// \brief Shape calculation.
@ -168,6 +170,29 @@ class MS_CORE_API Emitter {
NodePtrList ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &shape_func, const ops::InferFunc &infer_func,
const std::vector<int64_t> &value_depend_indices = {}) const;
/// \brief Type of a control-flow block function: it receives an Emitter and returns the list of output nodes.
using BlockFunc = std::function<NodePtrList(const Emitter *)>;
/// \brief Generate a conditional (if-else) block.
///
/// \param[in] cond The condition node, it should be a tensor of Bool.
/// \param[in] true_case The function that builds the true branch.
/// \param[in] false_case The function that builds the false branch.
/// \return A node of tuple or single value, depending on the output lists of the two branches.
/// \note The overloaded operators (like `a+b`) should not be used for captured variables in the true_case/false_case
/// functions; use the function argument `Emitter` instead, like `emitter->Add(a, b)`. The output lists of the two
/// branches should match the join rules of control flow.
NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) const;
/// \brief Generate a while-loop block.
///
/// \param[in] cond The condition node, it should be a tensor of Bool.
/// \param[in] body The function that builds the loop body.
/// \param[in] init_list The initial variables that would be modified in the body.
/// \return A node of tuple or single value, depending on init_list.
/// \note The overloaded operators (like `a+b`) should not be used for captured variables in the body function; use
/// the function argument `Emitter` instead, like `emitter->Add(a, b)`. The length and node order of the output list
/// of the body function should match init_list.
NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) const;
protected:
/// \brief Wrap an AnfNode into a new Node that is constructed with this emitter.
NodePtr NewNode(const AnfNodePtr &anfnode) const { return std::make_shared<Node>(anfnode, this); }
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {
@ -180,6 +205,8 @@ class MS_CORE_API Emitter {
return Emit(op, {lhs, rhs}, attrs);
}
class CtrlFlowBlock;
FuncGraphPtr func_graph_;
ExpanderInferPtr infer_{nullptr};
ScopePtr scope_{nullptr};

View File

@ -84,10 +84,8 @@ void CppInfer::Infer(const NodePtr &node) {
});
AbstractBasePtr result = nullptr;
auto found = abstract::GetPrimitiveInferImpl(prim);
if (found.has_value()) {
auto infer = found.value();
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
result = infer.InferShapeAndType(nullptr, prim, abs_list);
if (found.has_value() && found.value().IsImplInferShapeAndType()) {
result = found.value().InferShapeAndType(nullptr, prim, abs_list);
} else {
auto iter = unreg_infer_map.find(prim);
if (iter != unreg_infer_map.end()) {