!45016 optimize getitem and add new interface to maintain users
Merge pull request !45016 from DeshiChen/1102_bprop
This commit is contained in:
commit
3400d592ae
|
@ -17,46 +17,189 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "common/graph_kernel/bprop/expander/infer.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace expander {
|
||||
namespace bprop {
|
||||
class BpropExpander {
|
||||
public:
|
||||
/// \brief Create an expander that writes results into caller-owned containers.
/// \param outputs   Receives the CNodes produced by the bprop builder.
/// \param dout_user When non-null, PostProcess records the users of dout into it.
/// \param users     Used by PostProcess when dout_user is null: records users of dout and of Parameter inputs.
/// The pointers are stored as-is; this class does not take ownership of them.
BpropExpander(CNodePtrList *outputs, DoutUserType *dout_user, UserType *users)
    : outputs_(outputs), dout_user_(dout_user), users_(users) {}
|
||||
~BpropExpander() = default;
|
||||
|
||||
/// \brief Wrap every input of `cnode` except input 0 into an expander Node bound to `ir_builder`.
/// \param cnode      The CNode whose inputs are wrapped.
/// \param ir_builder Raw, non-owning pointer handed to each created Node.
/// \return The wrapped inputs, in the same order as in `cnode`.
NodePtrList ExtractInputs(const CNodePtr &cnode, const BpropIRBuilder *ir_builder) {
  NodePtrList nodes;
  const auto &cnode_inputs = cnode->inputs();
  // Nothing to wrap when only input 0 (or nothing at all) is present; this also keeps
  // `cbegin() + 1` below from stepping past the end of an empty list.
  if (cnode_inputs.size() <= 1) {
    return nodes;
  }
  nodes.reserve(cnode_inputs.size() - 1);
  (void)std::transform(cnode_inputs.cbegin() + 1, cnode_inputs.cend(), std::back_inserter(nodes),
                       [ir_builder](const AnfNodePtr &no) { return std::make_shared<Node>(no, ir_builder); });
  return nodes;
}
|
||||
|
||||
bool Run(const CNodePtr &cnode) {
|
||||
auto infer = std::make_shared<CppInfer>();
|
||||
auto name = AnfUtils::GetCNodeName(cnode);
|
||||
auto ir_builder = std::make_shared<BpropIRBuilder>(name, cnode->func_graph(), infer);
|
||||
auto inputs = ExtractInputs(cnode, ir_builder);
|
||||
auto ir_builder = std::make_unique<BpropIRBuilder>(name, cnode->func_graph(), infer);
|
||||
auto inputs = ExtractInputs(cnode, ir_builder.get());
|
||||
auto &attrs = GetCNodePrimitive(cnode)->attrs();
|
||||
return ir_builder->Run(inputs, attrs, outputs_, dout_user_);
|
||||
auto ret = ir_builder->Run(inputs, attrs, outputs_);
|
||||
if (!ret) {
|
||||
return false;
|
||||
}
|
||||
PostProcess(inputs);
|
||||
static bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
|
||||
if (dump_result) {
|
||||
DumpResult(name, inputs);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void PostProcess(const NodePtrList &inputs) const {
|
||||
std::set<AnfNodePtr> visited;
|
||||
// do not visit the inputs again.
|
||||
std::for_each(inputs.cbegin(), inputs.cend(), [&visited](const NodePtr &node) { visited.insert(node->get()); });
|
||||
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
AnfNodePtr dout = inputs.back()->get();
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
// record parameter's and dout's user
|
||||
if (dout_user_ != nullptr) {
|
||||
if (inp == dout) {
|
||||
(void)dout_user_->emplace_back(node, i);
|
||||
}
|
||||
} else { // users_ != nullptr
|
||||
if (inp == dout || inp->isa<Parameter>()) {
|
||||
(*users_)[inp].emplace_back(node, i);
|
||||
}
|
||||
}
|
||||
if (IsPrimitiveCNode(inp, prim::kPrimTupleGetItem)) {
|
||||
auto getitem = inp->cast<CNodePtr>();
|
||||
auto real_input = getitem->input(kIndex1);
|
||||
// record the dout's successor getitem's users
|
||||
if (users_ != nullptr && real_input == dout) {
|
||||
(*users_)[inp].emplace_back(node, i);
|
||||
} else if (real_input->isa<ValueNode>()) {
|
||||
// eliminate redundant getitem
|
||||
auto real_input_value = real_input->cast<ValueNodePtr>()->value();
|
||||
if (real_input_value->isa<ValueSequence>()) {
|
||||
auto item_idx = GetValue<int64_t>(getitem->input(kIndex2)->cast<ValueNodePtr>()->value());
|
||||
auto newnode = NewValueNode((*(real_input_value->cast<ValueSequencePtr>()))[item_idx]);
|
||||
newnode->set_abstract(newnode->value()->ToAbstract());
|
||||
node->set_input(i, newnode);
|
||||
continue; // do not visit the getitem again from this node
|
||||
}
|
||||
}
|
||||
}
|
||||
if (inp->isa<CNode>() && visited.count(inp) == 0) {
|
||||
(void)visited.insert(inp);
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DumpResult(const std::string &name, const NodePtrList &inputs) const {
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
std::map<AnfNodePtr, AnfNodePtr> node_map;
|
||||
CNodePtrList newcnodes;
|
||||
for (auto &inp : inputs) {
|
||||
auto p = fg->add_parameter();
|
||||
p->set_abstract(inp->get()->abstract());
|
||||
node_map[inp->get()] = p;
|
||||
}
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
if (node_map.count(node)) {
|
||||
continue;
|
||||
}
|
||||
auto new_node = fg->NewCNode(node->inputs());
|
||||
new_node->CloneCNodeInfo(node);
|
||||
new_node->set_fullname_with_scope(node->fullname_with_scope());
|
||||
node_map[node] = new_node;
|
||||
newcnodes.push_back(new_node);
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
if (inp->isa<CNode>() && node_map.count(inp) == 0) {
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &cnode : newcnodes) {
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
if (node_map.count(cnode->input(i)) != 0) {
|
||||
cnode->set_input(i, node_map[cnode->input(i)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (outputs_->size() == 1) {
|
||||
fg->set_output(node_map[(*outputs_)[0]]);
|
||||
} else {
|
||||
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtrList abs;
|
||||
(void)std::transform(outputs_->cbegin(), outputs_->cend(), std::back_inserter(new_outputs),
|
||||
[&node_map, &abs](const CNodePtr &node) {
|
||||
abs.push_back(node->abstract());
|
||||
return node_map[node];
|
||||
});
|
||||
auto mt = fg->NewCNode(new_outputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
|
||||
fg->set_output(mt);
|
||||
}
|
||||
DumpIR("bprop/bprop_expander_" + name + ".ir", fg, true);
|
||||
|
||||
if (dout_user_ != nullptr) {
|
||||
for (auto &iter : *dout_user_) {
|
||||
MS_LOG(INFO) << "Dout User: " << iter.first->fullname_with_scope() << " index: " << iter.second;
|
||||
}
|
||||
} else { // users_ != nullptr
|
||||
for (auto &uiter : *users_) {
|
||||
for (auto &iter : uiter.second) {
|
||||
MS_LOG(INFO) << "Node " << uiter.first->ToString() << " user: " << iter.first->fullname_with_scope()
|
||||
<< " index: " << iter.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CNodePtrList *outputs_;
|
||||
expander::bprop::DoutUser *dout_user_;
|
||||
DoutUserType *dout_user_;
|
||||
UserType *users_;
|
||||
};
|
||||
} // namespace bprop
|
||||
} // namespace expander
|
||||
|
||||
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, expander::bprop::DoutUser *dout_user) {
|
||||
// deprecated
|
||||
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(dout_user);
|
||||
expander::bprop::BpropExpander e(outputs, dout_user);
|
||||
expander::bprop::BpropExpander e(outputs, dout_user, nullptr);
|
||||
(void)e.Run(cnode);
|
||||
}
|
||||
|
||||
void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(users);
|
||||
expander::bprop::BpropExpander e(outputs, nullptr, users);
|
||||
(void)e.Run(cnode);
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,11 +16,19 @@
|
|||
#ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
|
||||
#define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
#include "include/common/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, expander::bprop::DoutUser *dout_user);
|
||||
using DoutUserType = std::vector<std::pair<CNodePtr, int>>;
|
||||
// deprecated
|
||||
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, DoutUserType *dout_user);
|
||||
|
||||
using UserType = std::map<AnfNodePtr, std::vector<std::pair<CNodePtr, int>>>;
|
||||
COMMON_EXPORT void BuildBprop(const CNodePtr &cnode, CNodePtrList *outputs, UserType *users);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_BPROP_H_
|
||||
|
|
|
@ -17,13 +17,9 @@
|
|||
#include "common/graph_kernel/bprop/bprop_irbuilder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "include/common/debug/anf_ir_dump.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -40,10 +36,8 @@ int64_t CheckRange(int64_t idx, int64_t dim_size) {
|
|||
}
|
||||
} // namespace
|
||||
|
||||
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, std::vector<CNodePtr> *outputs,
|
||||
DoutUser *dout_user) {
|
||||
bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs) {
|
||||
MS_EXCEPTION_IF_NULL(outputs);
|
||||
MS_EXCEPTION_IF_NULL(dout_user);
|
||||
if (!BpropIRBuilderFactory::Instance().HasOp(name())) {
|
||||
return false;
|
||||
}
|
||||
|
@ -58,100 +52,9 @@ bool BpropIRBuilder::Run(const NodePtrList &inputs, const DAttr &attrs, std::vec
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return cnode;
|
||||
});
|
||||
FindDoutUsers(*outputs, dout_user);
|
||||
if (common::GetEnv("MS_DEV_DUMP_BPROP") == "on") {
|
||||
DumpResult(*outputs, *dout_user);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void BpropIRBuilder::FindDoutUsers(const std::vector<CNodePtr> &outputs, DoutUser *dout_user) const {
|
||||
std::set<AnfNodePtr> visited;
|
||||
// do not visit the inputs again.
|
||||
std::for_each(inputs_ptr_->cbegin(), inputs_ptr_->cend(),
|
||||
[&visited](const NodePtr &node) { visited.insert(node->get()); });
|
||||
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs.cbegin(), outputs.cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
AnfNodePtr dout = inputs_ptr_->back()->get();
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
if (inp == dout) {
|
||||
(void)dout_user->emplace_back(node, i);
|
||||
}
|
||||
if (inp->isa<CNode>() && visited.count(inp) == 0) {
|
||||
(void)visited.insert(inp);
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void BpropIRBuilder::DumpResult(const std::vector<CNodePtr> &outputs, const DoutUser &dout_user) const {
|
||||
auto fg = std::make_shared<FuncGraph>();
|
||||
std::map<AnfNodePtr, AnfNodePtr> node_map;
|
||||
CNodePtrList newcnodes;
|
||||
for (auto &inp : *inputs_ptr_) {
|
||||
auto p = fg->add_parameter();
|
||||
p->set_abstract(inp->get()->abstract());
|
||||
node_map[inp->get()] = p;
|
||||
}
|
||||
std::queue<CNodePtr> que;
|
||||
std::for_each(outputs.cbegin(), outputs.cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
|
||||
|
||||
while (!que.empty()) {
|
||||
auto node = que.front();
|
||||
que.pop();
|
||||
if (node_map.count(node)) {
|
||||
continue;
|
||||
}
|
||||
auto new_node = fg->NewCNode(node->inputs());
|
||||
new_node->CloneCNodeInfo(node);
|
||||
new_node->set_fullname_with_scope(node->fullname_with_scope());
|
||||
node_map[node] = new_node;
|
||||
newcnodes.push_back(new_node);
|
||||
for (size_t i = 1; i < node->size(); ++i) {
|
||||
const auto &inp = node->input(i);
|
||||
if (inp->isa<CNode>() && node_map.count(inp) == 0) {
|
||||
que.push(inp->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &cnode : newcnodes) {
|
||||
for (size_t i = 1; i < cnode->size(); i++) {
|
||||
if (node_map.count(cnode->input(i)) != 0) {
|
||||
cnode->set_input(i, node_map[cnode->input(i)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (outputs.size() == 1) {
|
||||
fg->set_output(node_map[outputs[0]]);
|
||||
} else {
|
||||
AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
AbstractBasePtrList abs;
|
||||
(void)std::transform(outputs.cbegin(), outputs.cend(), std::back_inserter(new_outputs),
|
||||
[&node_map, &abs](const CNodePtr &node) {
|
||||
abs.push_back(node->abstract());
|
||||
return node_map[node];
|
||||
});
|
||||
auto mt = fg->NewCNode(new_outputs);
|
||||
mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
|
||||
fg->set_output(mt);
|
||||
}
|
||||
|
||||
for (auto &iter : dout_user) {
|
||||
MS_LOG(INFO) << "Dout User: " << iter.first->fullname_with_scope() << " index: " << iter.second;
|
||||
}
|
||||
|
||||
DumpIR("bprop/bprop_expander_" + name() + ".ir", fg, true);
|
||||
}
|
||||
|
||||
ValuePtr BpropIRBuilder::GetAttr(const std::string &attr) const {
|
||||
auto iter = attrs_ptr_->find(attr);
|
||||
if (iter != attrs_ptr_->end()) {
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <functional>
|
||||
|
||||
|
@ -30,14 +29,13 @@
|
|||
namespace mindspore {
|
||||
namespace expander {
|
||||
namespace bprop {
|
||||
using DoutUser = std::vector<std::pair<CNodePtr, int>>;
|
||||
class BpropIRBuilder : public Emitter {
|
||||
public:
|
||||
BpropIRBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
|
||||
: Emitter(func_graph, infer), name_(name) {}
|
||||
|
||||
/// \brief Run irbuilder to generate a graph
|
||||
bool Run(const NodePtrList &inputs, const DAttr &attrs, std::vector<CNodePtr> *outputs, DoutUser *dout_user);
|
||||
bool Run(const NodePtrList &inputs, const DAttr &attrs, CNodePtrList *outputs);
|
||||
|
||||
ValuePtr GetAttr(const std::string &attr) const;
|
||||
template <typename S>
|
||||
|
@ -69,11 +67,7 @@ class BpropIRBuilder : public Emitter {
|
|||
// case 3: x[..., 0:3:2, 0::2, :] => StridedSlice(x, {{-3,{0,3,2}}, {-2,{0,LLONG_MAX,2}}})
|
||||
NodePtr StridedSlice(const NodePtr &x, const std::map<int64_t, std::vector<int64_t>> &slices) const;
|
||||
|
||||
void DumpResult(const std::vector<CNodePtr> &outputs, const DoutUser &dout_user) const;
|
||||
|
||||
protected:
|
||||
void FindDoutUsers(const std::vector<CNodePtr> &outputs, DoutUser *dout_user) const;
|
||||
|
||||
std::string name_;
|
||||
const NodePtrList *inputs_ptr_{nullptr};
|
||||
const DAttr *attrs_ptr_{nullptr};
|
||||
|
|
Loading…
Reference in New Issue