forked from mindspore-Ecosystem/mindspore
Expander supports control flow (if-else and while-loop).
This commit is contained in:
parent
360add4014
commit
451f8b5756
|
@ -43,45 +43,63 @@ const std::map<std::string, std::vector<std::string>> op2attrs = {
|
|||
{prim::kPrimMatMul->name(), {kTransposeA, kTransposeB}}};
|
||||
}
|
||||
|
||||
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto todos = TopoSort(graph->get_return());
|
||||
auto mng = Manage({graph}, false);
|
||||
for (const auto &node : todos) {
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
|
||||
continue;
|
||||
}
|
||||
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
|
||||
continue;
|
||||
}
|
||||
if (primitive->isa<prim::DoSignaturePrimitive>()) {
|
||||
continue;
|
||||
}
|
||||
parallel::OperatorAttrs attrs;
|
||||
const auto iter = op2attrs.find(primitive->name());
|
||||
if (iter != op2attrs.end()) {
|
||||
for (auto &attr : iter->second) {
|
||||
if (primitive->HasAttr(attr)) {
|
||||
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
|
||||
} else {
|
||||
MS_LOG(WARNING) << primitive->name() << " op do not have attr: " << attr;
|
||||
return false;
|
||||
// Walks a FuncGraph in topological order and replaces every eligible real-kernel CNode with a new CNode whose
// primitive is rebuilt through parallel::CreateOpInstance. FuncGraphs found as ValueNode values are processed
// recursively; visited_graphs_ ensures each graph is processed at most once.
// NOTE(review): this text comes from a rendered diff. The statements between the end of the for-loop and the
// first `return true;`, and the second `return true;`, duplicate earlier code and look like removed-side lines
// of the previous ConvertPrimToPrimPy implementation -- confirm against the repository.
class PrimpyConverter {
|
||||
public:
|
||||
// Converts the nodes of `graph` in place through a manager created for this graph.
// Returns false as soon as a primitive listed in op2attrs lacks one of its listed attributes, true otherwise.
bool Run(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
(void)visited_graphs_.insert(graph);
|
||||
auto todos = TopoSort(graph->get_return());
|
||||
auto mng = Manage({graph}, false);
|
||||
for (const auto &node : todos) {
|
||||
if (node->isa<ValueNode>()) {
|
||||
auto sub_graph = node->cast<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
|
||||
if (sub_graph != nullptr && visited_graphs_.count(sub_graph) == 0) {
|
||||
// The result of the recursive conversion is discarded: a failure in a sub-graph does not stop this graph.
(void)Run(sub_graph);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
|
||||
continue;
|
||||
}
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
// Nothing to do when there is no primitive or it is already a PrimitivePy.
if (primitive == nullptr || primitive->isa<PrimitivePy>()) {
|
||||
continue;
|
||||
}
|
||||
// Primitives with a registered frontend infer implementation are left untouched.
if (abstract::GetFrontendPrimitiveInferImpl(primitive).has_value()) {
|
||||
continue;
|
||||
}
|
||||
if (primitive->isa<prim::DoSignaturePrimitive>()) {
|
||||
continue;
|
||||
}
|
||||
// Collect the attributes that op2attrs lists for this op; they are handed to CreateOpInstance below.
parallel::OperatorAttrs attrs;
|
||||
const auto iter = op2attrs.find(primitive->name());
|
||||
if (iter != op2attrs.end()) {
|
||||
for (auto &attr : iter->second) {
|
||||
if (primitive->HasAttr(attr)) {
|
||||
(void)attrs.emplace_back(std::pair{attr, primitive->GetAttr(attr)});
|
||||
} else {
|
||||
MS_LOG(WARNING) << primitive->name() << " op do not have attr: " << attr;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
|
||||
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
|
||||
// Build the replacement node: new primitive first, then the original inputs except input 0.
AnfNodePtrList inputs = {NewValueNode(new_prim)};
|
||||
auto cnode = dyn_cast_ptr<CNode>(node);
|
||||
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
|
||||
auto new_cnode = graph->NewCNodeInOrder(inputs);
|
||||
(void)mng->Replace(node, new_cnode);
|
||||
}
|
||||
auto new_prim = parallel::CreateOpInstance(attrs, primitive->name(), "");
|
||||
(void)new_prim->cast_ptr<Primitive>()->SetAttrs(primitive->attrs());
|
||||
AnfNodePtrList inputs = {NewValueNode(new_prim)};
|
||||
auto cnode = dyn_cast_ptr<CNode>(node);
|
||||
(void)inputs.insert(inputs.cend(), cnode->inputs().cbegin() + 1, cnode->inputs().cend());
|
||||
auto new_cnode = graph->NewCNodeInOrder(inputs);
|
||||
(void)mng->Replace(node, new_cnode);
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
|
||||
private:
|
||||
// Graphs already handled by Run; prevents processing the same graph twice during recursion.
std::set<FuncGraphPtr> visited_graphs_;
|
||||
};
|
||||
// Entry point: runs a fresh PrimpyConverter on `graph` and forwards its result.
bool ConvertPrimToPrimPy(const FuncGraphPtr &graph) {
  return PrimpyConverter().Run(graph);
}
|
||||
|
||||
#ifdef ENABLE_AKG
|
||||
|
|
|
@ -247,6 +247,18 @@ void BpropExpanderInGraphMode::PostProcess() const {
|
|||
auto mt = fg_->NewCNode(new_outputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
|
||||
fg_->set_output(mt);
|
||||
|
||||
// Clear all abstracts so that the specializer re-infers the subgraphs of control-flow graphs.
|
||||
auto todos = TopoSort(fg_->get_return(), SuccDeeperSimple, AlwaysInclude);
|
||||
for (auto &no : todos) {
|
||||
no->set_abstract(nullptr);
|
||||
if (IsValueNode<FuncGraph>(no)) {
|
||||
auto fg = GetValueNode<FuncGraphPtr>(no);
|
||||
for (auto &p : fg->parameters()) {
|
||||
p->set_abstract(nullptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BpropExpanderInGraphMode::DumpResult(const std::string &name) const {
|
||||
|
|
|
@ -392,6 +392,208 @@ std::tuple<NodePtr, NodePtr> Emitter::UnifyDtype2(const NodePtr &lhs, const Node
|
|||
return {lhs, this->Cast(rhs, lhs->dtype())};
|
||||
}
|
||||
|
||||
/// \brief Helper that builds control-flow structures (if-else and while-loop) on behalf of an Emitter.
///
/// Every branch or loop body is emitted into its own FuncGraph through a dedicated sub-Emitter. The output count
/// and the output abstract of the first block built are recorded and reused for all generated call nodes.
class Emitter::CtrlFlowBlock {
 public:
  explicit CtrlFlowBlock(const Emitter *emitter) : emitter_(emitter) { MS_EXCEPTION_IF_NULL(emitter); }
  ~CtrlFlowBlock() = default;

  /// \brief Emit `Switch(cond, true_graph, false_graph)` and a call of its result in the emitter's graph.
  ///
  /// \param[in] cond condition node.
  /// \param[in] true_case block function generating the true branch.
  /// \param[in] false_case block function generating the false branch.
  /// \return node wrapping the call; its abstract is the output abstract of the first branch built.
  NodePtr IfThenElse(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
    auto tb = BuildSubgraph(true_case);
    auto fb = BuildSubgraph(false_case);
    auto s = emitter_->Emit("Switch", {cond, tb, fb});
    auto cnode = emitter_->func_graph_->NewCNode({s->get()});
    cnode->set_abstract(out_abstract_);
    auto node = emitter_->NewNode(cnode->cast<AnfNodePtr>());
    return node;
  }

  /// \brief Emit a while-loop as a self-calling graph.
  ///
  /// A new graph `while_fg` is created and `cond` is moved into it. Nodes captured by the condition, by the
  /// init list and by the body become parameters of `while_fg`; the original nodes are passed as arguments by
  /// the call placed in the emitter's graph. The body graph's output is replaced by a call to `while_fg`, and
  /// `while_fg` outputs the call of `Switch(cond, Partial(body, ...), Partial(empty_body, ...))`.
  ///
  /// \param[in] cond condition node, must hold a CNode.
  /// \param[in] while_body_func block function generating the loop body.
  /// \param[in] init_list initial values; the body must return the same number of outputs.
  /// \return node wrapping the call of `while_fg` in the emitter's graph.
  NodePtr While(const NodePtr &cond, const BlockFunc &while_body_func, const NodePtrList &init_list) {
    auto while_fg = std::make_shared<FuncGraph>();
    MS_EXCEPTION_IF_NULL(while_fg);
    auto cond_cnode = cond->get<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cond_cnode);
    cond_cnode->set_func_graph(while_fg);
    auto while_fg_emitter = std::make_unique<Emitter>(while_fg, std::make_shared<CppInferWithPartial>());
    AnfNodePtrList main_while_fg_inputs = {NewValueNode(while_fg)};
    std::map<AnfNodePtr, ParameterPtr> param_map;
    // Map an outer node to a parameter of while_fg, creating it (and the matching call argument) on first use.
    auto replace_by_param = [&main_while_fg_inputs, &param_map, &while_fg](const AnfNodePtr &inp) {
      auto &param = param_map[inp];
      if (param == nullptr) {
        param = while_fg->add_parameter();
        param->set_abstract(inp->abstract());
        (void)main_while_fg_inputs.emplace_back(inp);
      }
      return param;
    };

    // The "false" branch of the loop simply returns the current values.
    auto empty_body_func = [&init_list](const Emitter *) { return init_list; };
    auto empty_body_fg_with_inputs = BuildSubgraphOfPartial(empty_body_func);
    for (size_t i = 1; i < empty_body_fg_with_inputs.size(); i++) {
      auto inp = empty_body_fg_with_inputs[i]->get();
      empty_body_fg_with_inputs[i] = while_fg_emitter->NewNode(replace_by_param(inp));
    }
    for (size_t i = 1; i < cond_cnode->size(); i++) {
      auto inp = cond_cnode->input(i);
      if (!inp->isa<ValueNode>()) {
        cond_cnode->set_input(i, replace_by_param(inp));
      }
    }

    auto body_with_inputs = BuildSubgraphOfPartial(while_body_func);
    auto body_fg = body_with_inputs[0]->get<ValueNodePtr>()->value()->cast<FuncGraphPtr>();
    for (size_t i = 1; i < body_with_inputs.size(); i++) {
      body_with_inputs[i] = while_fg_emitter->NewNode(replace_by_param(body_with_inputs[i]->get()));
    }
    // Replace the body's output with a call to the enclosing while_fg.
    AnfNodePtrList body_while_fg_inputs{NewValueNode(while_fg)};
    if (IsPrimitiveCNode(body_fg->output(), prim::kPrimMakeTuple)) {
      auto mt = body_fg->output()->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(mt);
      (void)body_while_fg_inputs.insert(body_while_fg_inputs.end(), mt->inputs().begin() + 1, mt->inputs().end());
    } else {
      body_while_fg_inputs.push_back(body_fg->output());
    }
    if (body_while_fg_inputs.size() - 1 != init_list.size()) {
      MS_LOG(EXCEPTION) << "The while body's output size should be equal to init_list.size(), but got "
                        << (body_while_fg_inputs.size() - 1) << " vs " << init_list.size();
    }
    // while_fg may have more parameters than the body has outputs; forward the remaining ones unchanged,
    // adding a body parameter when the body does not already receive that value.
    if (body_while_fg_inputs.size() < main_while_fg_inputs.size()) {
      for (size_t i = body_while_fg_inputs.size(); i < main_while_fg_inputs.size(); i++) {
        auto inp = while_fg->parameters()[i - 1];
        auto iter = std::find_if(body_with_inputs.begin(), body_with_inputs.end(),
                                 [&inp](const NodePtr &no) { return no->get() == inp; });
        if (iter != body_with_inputs.end()) {
          auto param_idx = iter - body_with_inputs.begin() - 1;
          body_while_fg_inputs.push_back(body_fg->parameters()[param_idx]);
        } else {
          body_with_inputs.push_back(while_fg_emitter->NewNode(inp));
          auto p = body_fg->add_parameter();
          p->set_abstract(inp->abstract());
          body_while_fg_inputs.push_back(p);
        }
      }
    }
    auto body_call_fg = body_fg->NewCNode(body_while_fg_inputs);
    body_call_fg->set_abstract(out_abstract_);
    body_fg->set_output(body_call_fg);

    auto tb = while_fg_emitter->Emit("Partial", body_with_inputs);
    auto fb = while_fg_emitter->Emit("Partial", empty_body_fg_with_inputs);
    auto s = while_fg_emitter->Emit("Switch", {cond, tb, fb});
    auto cnode = while_fg_emitter->func_graph_->NewCNode({s->get()});
    cnode->set_abstract(out_abstract_);
    while_fg->set_output(cnode);

    auto main_cnode = emitter_->func_graph_->NewCNode(main_while_fg_inputs);
    main_cnode->set_abstract(out_abstract_);
    return emitter_->NewNode(main_cnode);
  }

 protected:
  /// \brief Emit `func` into a new FuncGraph and return a value node holding that graph.
  ///
  /// Raises an exception when `func` returns an empty list or when its output count differs from the count
  /// recorded for a previously built block. Multiple outputs are packed into a tuple.
  NodePtr BuildSubgraph(const BlockFunc &func) {
    auto fg = std::make_shared<FuncGraph>();
    MS_EXCEPTION_IF_NULL(fg);
    fg->set_switch_input(std::make_shared<bool>(true));
    auto e = std::make_unique<Emitter>(fg, emitter_->infer());
    auto output = func(e.get());
    if (output.empty()) {
      MS_LOG(EXCEPTION) << "The block function should not return empty list.";
    }
    if (output_num_ == 0) {
      output_num_ = output.size();
    } else if (output_num_ != output.size()) {
      MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
                        << " vs " << output.size() << ".";
    }
    if (output_num_ > 1) {
      auto mt = e->MakeTuple(output)->get();
      fg->set_output(mt);
      SetSequenceNodeElementsUseFlags(mt, std::make_shared<std::vector<bool>>(output_num_, true));
    } else {
      fg->set_output(output[0]->get());
    }
    if (out_abstract_ == nullptr) {
      out_abstract_ = fg->output()->abstract();
    }
    return emitter_->Value(fg);
  }

  /// \brief Emit `func` into a new FuncGraph and turn every captured outer node into a parameter.
  ///
  /// \return list whose first element is the value node of the new graph, followed by the captured outer
  /// nodes in the order of the parameters created for them (the input layout used for `Partial`).
  NodePtrList BuildSubgraphOfPartial(const BlockFunc &func) {
    auto fg = std::make_shared<FuncGraph>();
    MS_EXCEPTION_IF_NULL(fg);
    fg->set_switch_input(std::make_shared<bool>(true));
    auto sub_emitter = std::make_unique<Emitter>(fg, emitter_->infer());
    auto output = func(sub_emitter.get());
    if (output.empty()) {
      MS_LOG(EXCEPTION) << "The block function should not return empty list.";
    }
    if (output_num_ == 0) {
      output_num_ = output.size();
    } else if (output_num_ != output.size()) {
      MS_LOG(EXCEPTION) << "The count of outputs of each block function should be equal, but got " << output_num_
                        << " vs " << output.size() << ".";
    }
    fg->set_output((output_num_ > 1) ? sub_emitter->MakeTuple(output)->get() : output[0]->get());
    if (out_abstract_ == nullptr) {
      out_abstract_ = fg->output()->abstract();
    }
    if (output_num_ > 1) {
      SetSequenceNodeElementsUseFlags(fg->output(), std::make_shared<std::vector<bool>>(output_num_, true));
    }

    // Replace the captured inputs with parameters.
    std::function<void(const CNodePtr &)> dfs;
    std::unordered_set<AnfNodePtr> visited;
    std::map<AnfNodePtr, ParameterPtr> param_map;
    NodePtrList fg_with_inputs = {emitter_->Value(fg)};
    dfs = [&visited, &dfs, &fg, &param_map, &fg_with_inputs, this](const CNodePtr &node) {
      (void)visited.insert(node);
      for (size_t i = 0; i < node->size(); i++) {
        auto inp = node->input(i);
        // Inputs that belong to no graph are left as they are.
        if (inp->func_graph() == nullptr) {
          continue;
        }
        if (inp->func_graph() == fg) {
          if (inp->isa<CNode>() && visited.count(inp) == 0) {
            dfs(inp->cast<CNodePtr>());
          }
        } else {
          // Node owned by another graph: route it through a parameter, one parameter per distinct node.
          auto &param = param_map[inp];
          if (param == nullptr) {
            param = fg->add_parameter();
            param->set_abstract(inp->abstract());
            (void)fg_with_inputs.emplace_back(emitter_->NewNode(inp));
          }
          node->set_input(i, param);
        }
      }
    };
    dfs(fg->get_return());
    return fg_with_inputs;
  }

  size_t output_num_{0};                             // output count shared by all blocks; 0 until the first block
  const Emitter *emitter_;                           // emitter of the enclosing graph
  abstract::AbstractBasePtr out_abstract_{nullptr};  // output abstract of the first block built

  /// \brief CppInfer variant that performs no inference for Partial and Switch nodes.
  class CppInferWithPartial : public CppInfer {
   public:
    void Infer(const NodePtr &node) override {
      if (IsPrimitiveCNode(node->get(), prim::kPrimPartial) || IsPrimitiveCNode(node->get(), prim::kPrimSwitch)) {
        return;
      }
      CppInfer::Infer(node);
    }
  };
};
|
||||
|
||||
// Builds an if-else structure by delegating to a CtrlFlowBlock bound to this emitter.
NodePtr Emitter::Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) const {
  CtrlFlowBlock block(this);
  return block.IfThenElse(cond, true_case, false_case);
}
|
||||
|
||||
// Builds a while-loop structure by delegating to a CtrlFlowBlock bound to this emitter.
NodePtr Emitter::While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) const {
  CtrlFlowBlock block(this);
  return block.While(cond, body, init_list);
}
|
||||
|
||||
// Emits Add(lhs, rhs) through the emitter attached to lhs.
NodePtr operator+(const NodePtr &lhs, const NodePtr &rhs) {
  const auto emitter = lhs->emitter();
  return emitter->Add(lhs, rhs);
}
|
||||
// Emits Sub(lhs, rhs) through the emitter attached to lhs.
NodePtr operator-(const NodePtr &lhs, const NodePtr &rhs) {
  const auto emitter = lhs->emitter();
  return emitter->Sub(lhs, rhs);
}
|
||||
// Emits Mul(lhs, rhs) through the emitter attached to lhs.
NodePtr operator*(const NodePtr &lhs, const NodePtr &rhs) {
  const auto emitter = lhs->emitter();
  return emitter->Mul(lhs, rhs);
}
|
||||
|
|
|
@ -35,6 +35,7 @@ class MS_CORE_API Emitter {
|
|||
public:
|
||||
Emitter(const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer, const ScopePtr &scope = nullptr)
|
||||
: func_graph_(func_graph), infer_(infer), scope_(scope) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(infer);
|
||||
}
|
||||
|
||||
|
@ -153,6 +154,7 @@ class MS_CORE_API Emitter {
|
|||
return EmitValue(tensor_ptr);
|
||||
}
|
||||
|
||||
/// \brief get the ExpanderInferPtr
|
||||
ExpanderInferPtr infer() const { return infer_; }
|
||||
|
||||
/// \brief Shape calculation.
|
||||
|
@ -168,6 +170,29 @@ class MS_CORE_API Emitter {
|
|||
NodePtrList ShapeCalc(const NodePtrList &inputs, const ops::ShapeFunc &shape_func, const ops::InferFunc &infer_func,
|
||||
const std::vector<int64_t> &value_depend_indices = {}) const;
|
||||
|
||||
using BlockFunc = std::function<NodePtrList(const Emitter *)>;
|
||||
/// \brief Generate a conditional block.
|
||||
///
|
||||
/// \param[in] cond condition node, it should be a tensor of Bool.
|
||||
/// \param[in] true_case the true branch.
|
||||
/// \param[in] false_case the false branch.
|
||||
/// \return node of tuple or single value, which depends on the output lists of the two branches.
|
||||
/// \note The overloaded operators (like a+b) should not be used for captured variables in the true_case/false_case
|
||||
/// functions, use the function argument `Emitter` instead, like `emitter->Add(a, b)`. The output list of two branches
|
||||
/// should match the join rules of control flow.
|
||||
NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) const;
|
||||
|
||||
/// \brief Generate a while-loop block.
|
||||
///
|
||||
/// \param[in] cond condition node, it should be a tensor of Bool.
|
||||
/// \param[in] body the loop body.
|
||||
/// \param[in] init_list the initial variables that would be modified in body.
|
||||
/// \return node of tuple or single value, which depends on the init_list.
|
||||
/// \note The overloaded operators (like `a+b`) should not be used for captured variables in the body function, use
|
||||
/// the function argument `Emitter` instead, like `emitter->Add(a, b)`. The length and node order of the output list
|
||||
/// of the body function should match init_list.
|
||||
NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) const;
|
||||
|
||||
protected:
|
||||
// Wraps an AnfNode into a Node bound to this emitter.
NodePtr NewNode(const AnfNodePtr &anfnode) const {
  return std::make_shared<Node>(anfnode, this);
}
|
||||
NodePtr CmpOpWithCast(const std::string &op, const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) const {
|
||||
|
@ -180,6 +205,8 @@ class MS_CORE_API Emitter {
|
|||
return Emit(op, {lhs, rhs}, attrs);
|
||||
}
|
||||
|
||||
class CtrlFlowBlock;
|
||||
|
||||
FuncGraphPtr func_graph_;
|
||||
ExpanderInferPtr infer_{nullptr};
|
||||
ScopePtr scope_{nullptr};
|
||||
|
|
|
@ -84,10 +84,8 @@ void CppInfer::Infer(const NodePtr &node) {
|
|||
});
|
||||
AbstractBasePtr result = nullptr;
|
||||
auto found = abstract::GetPrimitiveInferImpl(prim);
|
||||
if (found.has_value()) {
|
||||
auto infer = found.value();
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(infer.IsImplInferShapeAndType(), "There is no infer-abstract implement!");
|
||||
result = infer.InferShapeAndType(nullptr, prim, abs_list);
|
||||
if (found.has_value() && found.value().IsImplInferShapeAndType()) {
|
||||
result = found.value().InferShapeAndType(nullptr, prim, abs_list);
|
||||
} else {
|
||||
auto iter = unreg_infer_map.find(prim);
|
||||
if (iter != unreg_infer_map.end()) {
|
||||
|
|
Loading…
Reference in New Issue