!45016 optimize getitem and add new interface to maintain users

Merge pull request !45016 from DeshiChen/1102_bprop
This commit is contained in:
i-robot 2022-11-03 06:48:01 +00:00 committed by Gitee
commit 3400d592ae
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 163 additions and 115 deletions

View File

@ -17,46 +17,189 @@
#include <algorithm>
#include <memory>
#include <queue>
#include <set>
#include <string>
#include "common/graph_kernel/bprop/expander/infer.h"
#include "utils/anf_utils.h"
#include "include/common/debug/anf_ir_dump.h"
namespace mindspore {
namespace expander {
namespace bprop {
class BpropExpander {
public:
BpropExpander(CNodePtrList *outputs, DoutUser *dout_user) : outputs_(outputs), dout_user_(dout_user) {}
BpropExpander(CNodePtrList *outputs, DoutUserType *dout_user, UserType *users)
: outputs_(outputs), dout_user_(dout_user), users_(users) {}
~BpropExpander() = default;
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilderPtr &ir_builder) {
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
NodePtrList nodes;
nodes.reserve(cnode->size());
(void)std::transform(cnode->inputs().cbegin() + 1, cnode->inputs().cend(), std::back_inserter(nodes),
[ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder.get()); });
[ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder); });
return nodes;
}
bool Run(const CNodePtr &cnode) {
auto infer = std::make_shared<CppInfer>();
auto name = AnfUtils::GetCNodeName(cnode);
auto ir_builder = std::make_shared<BpropIRBuilder>(name, cnode->func_graph(), infer);
auto inputs = ExtractInputs(cnode, ir_builder);
auto ir_builder = std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
auto inputs = ExtractInputs(cnode, ir_builder.get());
auto &attrs = GetCNodePrimitive(cnode)->attrs();
return ir_builder->Run(inputs, attrs, outputs_, dout_user_);
auto ret = ir_builder->Run(inputs, attrs, outputs_);
if (!ret) {
return false;
}
PostProcess(inputs);
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
if (dump_result) {
DumpResult(name, inputs);
}
return true;
}
void PostProcess(const NodePtrList &inputs) const {
std::set<AnfNodePtr> visited;
// do not visit the inputs again.
std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); });
std::queue<CNodePtr> que;
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
AnfNodePtr dout = inputs.back()->get();
while (!que.empty()) {
auto node = que.front();
que.pop();
for (size_t i = 1; i < node->size(); ++i) {
const auto &inp = node->input(i);
// record parameter's and dout's user
if (dout_user_ != nullptr) {
if (inp == dout) {
(void)dout_user_->emplace_back(node, i);
}
} else { // users_ != nullptr
if (inp == dout || inp->isa<Parameter>()) {
(*users_)[inp].emplace_back(node, i);
}
}
if (IsPrimitiveCNode(inp, prim::kPrimTupleGetItem)) {
auto getitem = inp->cast<CNodePtr>();
auto real_input = getitem->input(kIndex1);
// record the dout's successor getitem's users
if (users_ != nullptr && real_input == dout) {
(*users_)[inp].emplace_back(node, i);
} else if (real_input->isa<ValueNode>()) {
// eliminate redundant getitem
auto real_input_value = real_input->cast<ValueNodePtr>()->value();
if (real_input_value->isa<ValueSequence>()) {
auto item_idx = GetValue<int64_t>(getitem->input(kIndex2)->cast<ValueNodePtr>()->value());
auto newnode = NewValueNode((*(real_input_value->cast<ValueSequencePtr>()))[item_idx]);
newnode->set_abstract(newnode->value()->ToAbstract());
node->set_input(i, newnode);
continue; // do not visit the getitem again from this node
}
}
}
if (inp->isa<CNode>() && visited.count(inp) == 0) {
(void)visited.insert(inp);
que.push(inp->cast<CNodePtr>());
}
}
}
}
void DumpResult(const std::string &name, const NodePtrList &inputs) const {
auto fg = std::make_shared<FuncGraph>();
std::map<AnfNodePtr, AnfNodePtr> node_map;
CNodePtrList newcnodes;
for (auto &inp : inputs) {
auto p = fg->add_parameter();
p->set_abstract(inp->get()->abstract());
node_map[inp->get()] = p;
}
std::queue<CNodePtr> que;
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
while (!que.empty()) {
auto node = que.front();
que.pop();
if (node_map.count(node)) {
continue;
}
auto new_node = fg->NewCNode(node->inputs());
new_node->CloneCNodeInfo(node);
new_node->set_fullname_with_scope(node->fullname_with_scope());
node_map[node] = new_node;
newcnodes.push_back(new_node);
for (size_t i = 1; i < node->size(); ++i) {
const auto &inp = node->input(i);
if (inp->isa<CNode>() && node_map.count(inp) == 0) {
que.push(inp->cast<CNodePtr>());
}
}
}
for (auto &cnode : newcnodes) {
for (size_t i = 1; i < cnode->size(); i++) {
if (node_map.count(cnode->input(i)) != 0) {
cnode->set_input(i, node_map[cnode->input(i)]);
}
}
}
if (outputs_->size() == 1) {
fg->set_output(node_map[(*outputs_)[0]]);
} else {
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abs;
(void)std::transform(outputs_->cbegin(), outputs_->cend(), std::back_inserter(new_outputs),
[&node_map, &abs](const CNodePtr &node) {
abs.push_back(node->abstract());
return node_map[node];
});
auto mt = fg->NewCNode(new_outputs);
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
fg->set_output(mt);
}
DumpIR("bprop/bprop_expander_" + name + ".ir", fg, true);
if (dout_user_ != nullptr) {
for (auto &iter : *dout_user_) {
MS_LOG(INFO) << "Dout User: " << iter.first->fullname_with_scope() << " index: " << iter.second;
}
} else { // users_ != nullptr
for (auto &uiter : *users_) {
for (auto &iter : uiter.second) {
MS_LOG(INFO) << "Node " << uiter.first->ToString() << " user: " << iter.first->fullname_with_scope()
<< " index: " << iter.second;
}
}
}
}
private:
CNodePtrList *outputs_;
expander::bprop::DoutUser *dout_user_;
DoutUserType *dout_user_;
UserType *users_;
};
} // namespace bprop
} // namespace expander
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, expander::bprop::DoutUser *dout_user) {
// deprecated
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(dout_user);
expander::bprop::BpropExpander e(outputs, dout_user);
expander::bprop::BpropExpander e(outputs, dout_user, nullptr);
(void)e.Run(cnode);
}
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(users);
expander::bprop::BpropExpander e(outputs, nullptr, users);
(void)e.Run(cnode);
}
} // namespace mindspore

View File

@ -16,11 +16,19 @@
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
#include <map>
#include <vector>
#include <utility>
#include "ir/anf.h"
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include "include/common/visible.h"
namespace mindspore {
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, expander::bprop::DoutUser *dout_user);
using DoutUserType = std::vector<std::pair<CNodePtr, int>>;
// deprecated
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_

View File

@ -17,13 +17,9 @@
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
#include <algorithm>
#include <queue>
#include <set>
#include <map>
#include <vector>
#include <limits>
#include "include/common/utils/utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "utils/ms_context.h"
namespace mindspore {
@ -40,10 +36,8 @@ int64_t CheckRange(int64_t idx, int64_t dim_size) {
}
} // namespace
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, std::vector<CNodePtr> *outputs,
DoutUser *dout_user) {
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
MS_EXCEPTION_IF_NULL(dout_user);
if (!BpropIRBuilderFactory::Instance().HasOp(name())) {
return false;
}
@ -58,100 +52,9 @@ bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, std::vec
MS_EXCEPTION_IF_NULL(cnode);
return cnode;
});
FindDoutUsers(*outputs, dout_user);
if (common::GetEnv("MS_DEV_DUMP_BPROP") == "on") {
DumpResult(*outputs, *dout_user);
}
return true;
}
void BpropIRBuilder::FindDoutUsers(const std::vector<CNodePtr> &outputs, DoutUser *dout_user) const {
std::set<AnfNodePtr> visited;
// do not visit the inputs again.
std::for_each(inputs_ptr_->cbegin(), inputs_ptr_->cend(),
[&visited](const NodePtr &node) { visited.insert(node->get()); });
std::queue<CNodePtr> que;
std::for_each(outputs.cbegin(), outputs.cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
AnfNodePtr dout = inputs_ptr_->back()->get();
while (!que.empty()) {
auto node = que.front();
que.pop();
for (size_t i = 1; i < node->size(); ++i) {
const auto &inp = node->input(i);
if (inp == dout) {
(void)dout_user->emplace_back(node, i);
}
if (inp->isa<CNode>() && visited.count(inp) == 0) {
(void)visited.insert(inp);
que.push(inp->cast<CNodePtr>());
}
}
}
}
void BpropIRBuilder::DumpResult(const std::vector<CNodePtr> &outputs, const DoutUser &dout_user) const {
auto fg = std::make_shared<FuncGraph>();
std::map<AnfNodePtr, AnfNodePtr> node_map;
CNodePtrList newcnodes;
for (auto &inp : *inputs_ptr_) {
auto p = fg->add_parameter();
p->set_abstract(inp->get()->abstract());
node_map[inp->get()] = p;
}
std::queue<CNodePtr> que;
std::for_each(outputs.cbegin(), outputs.cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
while (!que.empty()) {
auto node = que.front();
que.pop();
if (node_map.count(node)) {
continue;
}
auto new_node = fg->NewCNode(node->inputs());
new_node->CloneCNodeInfo(node);
new_node->set_fullname_with_scope(node->fullname_with_scope());
node_map[node] = new_node;
newcnodes.push_back(new_node);
for (size_t i = 1; i < node->size(); ++i) {
const auto &inp = node->input(i);
if (inp->isa<CNode>() && node_map.count(inp) == 0) {
que.push(inp->cast<CNodePtr>());
}
}
}
for (auto &cnode : newcnodes) {
for (size_t i = 1; i < cnode->size(); i++) {
if (node_map.count(cnode->input(i)) != 0) {
cnode->set_input(i, node_map[cnode->input(i)]);
}
}
}
if (outputs.size() == 1) {
fg->set_output(node_map[outputs[0]]);
} else {
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abs;
(void)std::transform(outputs.cbegin(), outputs.cend(), std::back_inserter(new_outputs),
[&node_map, &abs](const CNodePtr &node) {
abs.push_back(node->abstract());
return node_map[node];
});
auto mt = fg->NewCNode(new_outputs);
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
fg->set_output(mt);
}
for (auto &iter : dout_user) {
MS_LOG(INFO) << "Dout User: " << iter.first->fullname_with_scope() << " index: " << iter.second;
}
DumpIR("bprop/bprop_expander_" + name() + ".ir", fg, true);
}
ValuePtr BpropIRBuilder::GetAttr(const std::string &attr) const {
auto iter = attrs_ptr_->find(attr);
if (iter != attrs_ptr_->end()) {

View File

@ -19,7 +19,6 @@
#include <memory>
#include <vector>
#include <string>
#include <utility>
#include <map>
#include <functional>
@ -30,14 +29,13 @@
namespace mindspore {
namespace expander {
namespace bprop {
using DoutUser = std::vector<std::pair<CNodePtr, int>>;
class BpropIRBuilder : public Emitter {
public:
BpropIRBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
: Emitter(func_graph, infer), name_(name) {}
/// \brief Run irbuilder to generate a graph
bool Run(const NodePtrList &inputs, const DAttr &attrs, std::vector<CNodePtr> *outputs, DoutUser *dout_user);
bool Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs);
ValuePtr GetAttr(const std::string &attr) const;
template <typename S>
@ -69,11 +67,7 @@ class BpropIRBuilder : public Emitter {
// case 3: x[..., 0:3:2, 0::2, :] => StridedSlice(x, {{-3,{0,3,2}}, {-2,{0,LLONG_MAX,2}}})
NodePtr StridedSlice(const NodePtr &x, const std::map<int64_t, std::vector<int64_t>> &slices) const;
void DumpResult(const std::vector<CNodePtr> &outputs, const DoutUser &dout_user) const;
protected:
void FindDoutUsers(const std::vector<CNodePtr> &outputs, DoutUser *dout_user) const;
std::string name_;
const NodePtrList *inputs_ptr_{nullptr};
const DAttr *attrs_ptr_{nullptr};