optimize converted list to tuple

This commit is contained in:
limingqi107 2023-02-24 12:48:16 +08:00
parent f115941b24
commit 9108c6f60b
7 changed files with 109 additions and 150 deletions

View File

@ -172,7 +172,7 @@ void AddDynamicShapeAttrPass(const std::shared_ptr<session::KernelGraph> &kernel
PassManagerPtr GetEliminateIllegalDataTypePassManager() {
auto pm = std::make_shared<PassManager>("common_eliminate_illegal_data_type_pm");
pm->AddPass(std::make_shared<ConvertListToTuple>());
pm->AddPass(std::make_shared<ConvertListToTuple>("convert_list_to_tuple"));
pm->AddPass(std::make_shared<EliminateFuncDataType>());
return pm;
}

View File

@ -22,89 +22,30 @@
namespace mindspore {
namespace opt {
static const std::map<std::string, std::string> kOpListToTupleNames = {{prim::kMakeListNew, prim::kMakeTuple},
{prim::kListGetItem, prim::kTupleGetItem},
{prim::kListSetItem, prim::kTupleSetItem}};
static const size_t kMaxRecursiveDepth = 6;
const AnfNodePtr ConvertListToTuple::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
// Value list --> Value tuple.
if (node->isa<ValueNode>()) {
auto value_node = node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
bool need_convert = false;
auto convert_value = ConvertValueSequenceToValueTuple(value_node->value(), &need_convert);
if (need_convert) {
return std::make_shared<ValueNode>(convert_value);
}
return nullptr;
}
if (!node->isa<CNode>()) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// List name --> tuple name.
auto old_full_name = cnode->fullname_with_scope();
auto old_name = common::AnfAlgo::GetCNodeName(cnode);
auto iter = kOpListToTupleNames.find(old_name);
if (iter != kOpListToTupleNames.end()) {
auto primitive = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(primitive);
primitive->set_name(iter->second);
// Reset full scope name.
cnode->set_fullname_with_scope("");
common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
MS_LOG(DEBUG) << "Rename op from " << old_name << " to " << iter->second << " for op " << old_full_name << " to "
<< cnode->fullname_with_scope();
}
// List abstract --> tuple abstract.
auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract());
if (new_abs != nullptr) {
node->set_abstract(new_abs);
common::AnfAlgo::SetNodeAttr(kAttrAbstractAdaptationProcessed, MakeValue(true), cnode);
MS_LOG(DEBUG) << "Convert sequence abstract to tuple abstract for op " << old_full_name << ", new op name "
<< cnode->fullname_with_scope();
}
return nullptr;
}
// ValueSequence --> ValueTuple.
ValuePtr ConvertListToTuple::ConvertValueSequenceToValueTuple(const ValuePtr &value, bool *need_convert,
size_t depth) const {
MS_EXCEPTION_IF_NULL(need_convert);
MS_EXCEPTION_IF_NULL(value);
if (depth > kMaxRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxRecursiveDepth << " levels.";
}
if (value->isa<ValueSequence>()) {
std::vector<ValuePtr> elements;
auto value_seq = value->cast<ValueSequencePtr>();
(void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(elements),
[&](const ValuePtr &value) -> ValuePtr {
bool is_convert = false;
auto convert_value = ConvertValueSequenceToValueTuple(value, &is_convert, depth + 1);
*need_convert |= is_convert;
return convert_value;
});
*need_convert |= value->isa<ValueList>();
if (*need_convert) {
return std::make_shared<ValueTuple>(elements);
bool ConvertListToTuple::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
for (auto node : node_list) {
MS_EXCEPTION_IF_NULL(node);
// List abstract --> tuple abstract.
auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract());
if (new_abs != nullptr) {
node->set_abstract(new_abs);
if (node->isa<CNode>()) {
common::AnfAlgo::SetNodeAttr(kAttrAbstractAdaptationProcessed, MakeValue(true), node);
}
MS_LOG(INFO) << "Convert sequence abstract to tuple abstract for op:" << node->fullname_with_scope()
<< ",debug name:" << node->DebugString();
}
}
return value;
return true;
}
// AbstractSequence --> AbstractTuple.
AbstractBasePtr ConvertListToTuple::ConvertSequenceAbsToTupleAbs(const AbstractBasePtr &abs, size_t depth) const {
if (abs == nullptr) {
if (abs == nullptr || !abs->isa<abstract::AbstractSequence>()) {
return nullptr;
}
@ -113,48 +54,45 @@ AbstractBasePtr ConvertListToTuple::ConvertSequenceAbsToTupleAbs(const AbstractB
}
auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
if (abs_seq != nullptr) {
// Dynamic length sequence convert by the dynamic abs.
if (abs_seq->dynamic_len() && abs_seq->isa<abstract::AbstractList>()) {
auto converted_dynamic_abs_tuple =
std::make_shared<abstract::AbstractTuple>(abs_seq->elements(), abs_seq->sequence_nodes());
converted_dynamic_abs_tuple->set_dynamic_len(true);
converted_dynamic_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs());
return converted_dynamic_abs_tuple;
}
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertSequenceAbsToTupleAbs(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
if (changed_elements.empty()) {
if (abs->isa<abstract::AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<abstract::AbstractTuple>(seq_elements);
}
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
return std::make_shared<abstract::AbstractTuple>(std::move(elements));
MS_EXCEPTION_IF_NULL(abs_seq);
// Dynamic length sequence convert by the dynamic abs.
if (abs_seq->dynamic_len() && abs_seq->isa<abstract::AbstractList>()) {
auto converted_dynamic_abs_tuple =
std::make_shared<abstract::AbstractTuple>(abs_seq->elements(), abs_seq->sequence_nodes());
converted_dynamic_abs_tuple->set_dynamic_len(true);
converted_dynamic_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs());
return converted_dynamic_abs_tuple;
}
return nullptr;
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertSequenceAbsToTupleAbs(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
if (changed_elements.empty()) {
if (abs->isa<abstract::AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<abstract::AbstractTuple>(seq_elements);
}
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
return std::make_shared<abstract::AbstractTuple>(std::move(elements));
}
} // namespace opt
} // namespace mindspore

View File

@ -23,15 +23,13 @@
namespace mindspore {
namespace opt {
class BACKEND_EXPORT ConvertListToTuple : public PatternProcessPass {
class BACKEND_EXPORT ConvertListToTuple : public Pass {
public:
explicit ConvertListToTuple(bool multigraph = true) : PatternProcessPass("convert_list_to_tuple", multigraph) {}
explicit ConvertListToTuple(const std::string &name) : Pass("convert_list_to_tuple") {}
~ConvertListToTuple() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
bool Run(const FuncGraphPtr &graph) override;
private:
// ValueSequence --> ValueTuple.
ValuePtr ConvertValueSequenceToValueTuple(const ValuePtr &value, bool *need_convert, size_t depth = 0) const;
// AbstractSequence --> AbstractTuple.
AbstractBasePtr ConvertSequenceAbsToTupleAbs(const AbstractBasePtr &abs, size_t depth = 0) const;
};

View File

@ -994,7 +994,7 @@ void KernelGraphMgr::SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
auto input = common::AnfAlgo::GetInputNode(make_tuple, i);
auto node_abs = input->abstract();
MS_EXCEPTION_IF_NULL(node_abs);
if (node_abs->isa<abstract::AbstractTuple>()) {
if (node_abs->isa<abstract::AbstractSequence>()) {
MS_EXCEPTION_IF_CHECK_FAIL(
i == 0, "Input index: " + std::to_string(i) + " is a make tuple, input node: " + input->DebugString());
MS_LOG(DEBUG) << "Flatten the make tuple, input node: " << input->DebugString()

View File

@ -398,31 +398,6 @@ void MindRTBackendBase::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
}
}
namespace {
void ExchangeRealTupleGetItem(const FuncGraphPtr &root_graph) {
MS_EXCEPTION_IF_NULL(root_graph);
MS_EXCEPTION_IF_NULL(root_graph->manager());
FuncGraphSet graphs = root_graph->manager()->func_graphs();
for (const auto &graph : graphs) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = TopoSort(graph->get_return());
for (const auto &node : nodes) {
if (node == nullptr || (!node->isa<CNode>())) {
continue;
}
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (common::AnfAlgo::GetCNodeName(cnode) == prim::kTupleGetItem &&
cnode->inputs().size() == kTupleGetItemInputSize &&
(!cnode->input(kInputNodeOutputIndexInTupleGetItem)->isa<ValueNode>())) {
cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(prim::kRealTupleGetItem)));
MS_LOG(INFO) << "Exchange tuple get item to real tuple get item for node:" << cnode->DebugString();
}
}
}
}
} // namespace
const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(graph_compiler_);
MS_EXCEPTION_IF_NULL(func_graph);
@ -431,10 +406,9 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);
UnifyMindIR(root_graph);
root_graph_ = root_graph;
ExchangeRealTupleGetItem(root_graph);
// Register a summary callback function, which is called in the final stages of summary.
graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
@ -487,6 +461,49 @@ const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph
return actor_info;
}
void MindRTBackendBase::UnifyMindIR(const FuncGraphPtr &root_graph) {
MS_EXCEPTION_IF_NULL(root_graph);
MS_EXCEPTION_IF_NULL(root_graph->manager());
const std::map<std::string, std::string> kOpListToTupleNames = {{prim::kMakeListNew, prim::kMakeTuple},
{prim::kListGetItem, prim::kTupleGetItem},
{prim::kListSetItem, prim::kTupleSetItem}};
FuncGraphSet graphs = root_graph->manager()->func_graphs();
for (const auto &graph : graphs) {
MS_EXCEPTION_IF_NULL(graph);
auto nodes = TopoSort(graph->get_return());
for (const auto &node : nodes) {
if (node == nullptr || (!node->isa<CNode>())) {
continue;
}
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// List name --> tuple name.
auto iter = kOpListToTupleNames.find(common::AnfAlgo::GetCNodeName(cnode));
if (iter != kOpListToTupleNames.end()) {
common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(iter->second)));
// Reset full scope name.
cnode->set_fullname_with_scope("");
MS_LOG(INFO) << "Rename op from " << iter->first << " to " << iter->second << " for op "
<< cnode->fullname_with_scope() << ", debug name:" << cnode->DebugString();
}
// TupleGetItem --> RealTupleGetItem.
if (common::AnfAlgo::GetCNodeName(cnode) == prim::kTupleGetItem &&
cnode->inputs().size() == kTupleGetItemInputSize &&
(!cnode->input(kInputNodeOutputIndexInTupleGetItem)->isa<ValueNode>())) {
common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(prim::kRealTupleGetItem)));
// Reset full scope name.
cnode->set_fullname_with_scope("");
MS_LOG(INFO) << "Rename op from TupleGetItem to RealTupleGetItem for op " << cnode->fullname_with_scope()
<< ", debug name:" << cnode->DebugString();
}
}
}
}
void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
auto root_graph = WrapPrimitives(func_graph);
MS_EXCEPTION_IF_NULL(root_graph);

View File

@ -98,6 +98,9 @@ class BACKEND_EXPORT MindRTBackendBase : public Backend {
const VectorRef &args, VectorRef *outputs) {}
protected:
// Convert the nodes which are not supported in the backend.
void UnifyMindIR(const FuncGraphPtr &func_graph);
// The parameter func_graph is a graph, it can be either a root graph or a sub graph,
// The result of graph compiler is stored in graph_id_to_device_context_ and control_nodes_.
void CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode);

View File

@ -238,6 +238,9 @@ void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *devic
<< " addr:" << address;
address->set_from_persistent_mem(true);
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
} else {
MS_LOG(INFO) << "No device address for value node:" << value_node->fullname_with_scope()
<< ", debug name:" << common::AnfAlgo::GetNodeDebugString(value_node);
}
}
}